fix: add the batch extractor + handler

Closes #105, #86
This commit is contained in:
Philip Jenvey 2018-11-30 14:47:52 -08:00
parent d392e95e06
commit 1d3a8ae3ad
No known key found for this signature in database
GPG Key ID: 5B9F83DE4F7EB7FA
8 changed files with 325 additions and 58 deletions

View File

@ -75,7 +75,7 @@ impl Db for MockDb {
mock_db_method!(validate_batch, ValidateBatch);
mock_db_method!(append_to_batch, AppendToBatch);
mock_db_method!(get_batch, GetBatch, Option<results::GetBatch>);
mock_db_method!(delete_batch, DeleteBatch);
mock_db_method!(commit_batch, CommitBatch);
}
unsafe impl Send for MockDb {}

View File

@ -119,7 +119,7 @@ pub trait Db: Send + Debug {
fn get_batch(&self, params: params::GetBatch) -> DbFuture<Option<results::GetBatch>>;
fn delete_batch(&self, params: params::DeleteBatch) -> DbFuture<results::DeleteBatch>;
fn commit_batch(&self, params: params::CommitBatch) -> DbFuture<results::CommitBatch>;
fn box_clone(&self) -> Box<dyn Db>;

View File

@ -76,7 +76,7 @@ pub fn get(db: &MysqlDb, params: params::GetBatch) -> Result<Option<results::Get
.optional()?)
}
pub fn delete(db: &MysqlDb, params: params::DeleteBatch) -> Result<()> {
fn delete(db: &MysqlDb, params: params::DeleteBatch) -> Result<()> {
let user_id = params.user_id.legacy_id as i32;
let collection_id = db.get_collection_id(&params.collection)?;
diesel::delete(batches::table)

View File

@ -693,7 +693,6 @@ impl MysqlDb {
batch_db_method!(create_batch_sync, create, CreateBatch);
batch_db_method!(validate_batch_sync, validate, ValidateBatch);
batch_db_method!(append_to_batch_sync, append, AppendToBatch);
batch_db_method!(delete_batch_sync, delete, DeleteBatch);
batch_db_method!(commit_batch_sync, commit, CommitBatch);
pub fn get_batch_sync(&self, params: params::GetBatch) -> Result<Option<results::GetBatch>> {
@ -803,7 +802,7 @@ impl Db for MysqlDb {
GetBatch,
Option<results::GetBatch>
);
sync_db_method!(delete_batch, delete_batch_sync, DeleteBatch);
sync_db_method!(commit_batch, commit_batch_sync, CommitBatch);
}
#[derive(Debug, QueryableByName)]

View File

@ -85,7 +85,7 @@ impl ApiError {
false
}
fn is_conflict(&self) -> bool {
pub fn is_conflict(&self) -> bool {
match self.kind() {
ApiErrorKind::Db(dbe) => match dbe.kind() {
DbErrorKind::Conflict => return true,
@ -99,7 +99,10 @@ impl ApiError {
fn weave_error_code(&self) -> WeaveError {
match self.kind() {
ApiErrorKind::Validation(ver) => match ver.kind() {
ValidationErrorKind::FromDetails(ref _description, ref location, name) => {
ValidationErrorKind::FromDetails(ref description, ref location, name) => {
if description == "size-limit-exceeded" {
return WeaveError::SizeLimitExceeded;
}
let name = name.clone().unwrap_or("".to_owned());
if *location == RequestErrorLocation::Body
&& ["bso", "bsos"].contains(&name.as_str())

View File

@ -12,7 +12,6 @@ extern crate diesel_logger;
extern crate diesel_migrations;
extern crate docopt;
extern crate env_logger;
#[macro_use]
extern crate failure;
extern crate futures;
#[macro_use]

View File

@ -2,8 +2,7 @@
//!
//! Handles ensuring the headers, body, and query parameters are correct, extraction to
//! relevant types, and failing correctly with the appropriate errors if issues arise.
use std::collections::HashMap;
use std::str::FromStr;
use std::{self, collections::HashMap, str::FromStr};
use actix_web::http::header::{HeaderValue, ACCEPT, CONTENT_TYPE};
use actix_web::{
@ -17,7 +16,7 @@ use serde::de::{Deserialize, Deserializer, Error as SerdeError};
use serde_json::Value;
use validator::{Validate, ValidationError};
use db::{util::SyncTimestamp, Db, Sorting};
use db::{util::SyncTimestamp, Db, DbError, DbErrorKind, Sorting};
use error::{ApiError, ApiResult};
use server::ServerState;
use web::{auth::HawkPayload, error::ValidationErrorKind};
@ -34,6 +33,7 @@ lazy_static! {
Regex::new(r#"IV":\s*"AAAAAAAAAAAAAAAAAAAAAA=="#).unwrap();
static ref VALID_ID_REGEX: Regex = Regex::new(r"^[ -~]{1,64}$").unwrap();
static ref VALID_COLLECTION_ID_REGEX: Regex = Regex::new(r"^[a-zA-Z0-9._-]{1,32}$").unwrap();
static ref TRUE_REGEX: Regex = Regex::new("^(?i)true$").unwrap();
}
#[derive(Deserialize)]
@ -450,6 +450,7 @@ pub struct CollectionPostRequest {
pub user_id: HawkIdentifier,
pub query: BsoQueryParams,
pub bsos: BsoBodies,
pub batch: Option<BatchRequest>,
}
impl FromRequest<ServerState> for CollectionPostRequest {
@ -463,6 +464,7 @@ impl FromRequest<ServerState> for CollectionPostRequest {
/// - If the collection is 'crypto', known bad payloads are checked for
/// - Any valid BSO's beyond `BATCH_MAX_RECORDS` are moved to invalid
fn from_request(req: &HttpRequest<ServerState>, _: &Self::Config) -> Self::Result {
let req = req.clone();
let max_post_records = req.state().limits.max_post_records as i64;
let fut = <(
HawkIdentifier,
@ -470,7 +472,7 @@ impl FromRequest<ServerState> for CollectionPostRequest {
CollectionParam,
BsoQueryParams,
BsoBodies,
)>::extract(req).and_then(move |(user_id, db, collection, query, mut bsos)| {
)>::extract(&req).and_then(move |(user_id, db, collection, query, mut bsos)| {
let collection = collection.collection.clone();
if collection == "crypto" {
// Verify the client didn't mess up the crypto if we have a payload
@ -499,12 +501,18 @@ impl FromRequest<ServerState> for CollectionPostRequest {
}
}
let batch = match <Option<BatchRequest>>::extract(&req) {
Ok(batch) => batch,
Err(e) => return future::err(e.into()),
};
future::ok(CollectionPostRequest {
collection,
db,
user_id,
query,
bsos,
batch,
})
});
@ -706,8 +714,8 @@ pub struct BsoQueryParams {
pub offset: Option<u64>,
/// a comma-separated list of BSO ids (list of strings)
#[validate(custom = "validate_qs_ids")]
#[serde(deserialize_with = "deserialize_comma_sep_string", default)]
#[validate(custom = "validate_qs_ids")]
pub ids: Vec<String>,
// flag, whether to include full bodies (bool)
@ -737,6 +745,107 @@ impl FromRequest<ServerState> for BsoQueryParams {
}
}
/// Raw batch-related query-string parameters for a collection POST.
///
/// `batch` is the client-supplied batch token (an id, possibly base64-encoded,
/// or ""/"true" to request a new batch); `commit` must be "true"
/// (case-insensitive) when present, enforced by `validate_qs_commit`.
#[derive(Debug, Default, Clone, Deserialize, Validate)]
#[serde(default)]
pub struct BatchParams {
// Unvalidated batch token straight from the query string
pub batch: Option<String>,
// Only "true" (any case) is accepted; anything else is a 400
#[validate(custom = "validate_qs_commit")]
pub commit: Option<String>,
}
/// Validated batch options decoded from the query string.
#[derive(Debug, Default, Clone, Deserialize)]
pub struct BatchRequest {
/// Decoded numeric batch id; `None` means "create a new batch"
pub id: Option<i64>,
/// Whether to commit the batch after appending this request's BSOs
pub commit: bool,
}
impl FromRequest<ServerState> for Option<BatchRequest> {
type Config = ();
type Result = ApiResult<Option<BatchRequest>>;
fn from_request(req: &HttpRequest<ServerState>, _: &Self::Config) -> Self::Result {
let params = Query::<BatchParams>::from_request(req, &())
.map_err(|e| {
ValidationErrorKind::FromDetails(
e.to_string(),
RequestErrorLocation::QueryString,
None,
)
})?
.into_inner();
let limits = &req.state().limits;
let checks = [
("X-Weave-Records", limits.max_post_records),
("X-Weave-Bytes", limits.max_post_bytes),
("X-Weave-Total-Records", limits.max_total_records),
("X-Weave-Total-Bytes", limits.max_total_bytes),
];
for (header, limit) in &checks {
let value = match req.headers().get(*header) {
Some(value) => value.to_str().map_err(|e| {
let err: ApiError = ValidationErrorKind::FromDetails(
e.to_string(),
RequestErrorLocation::Header,
Some((*header).to_owned()),
).into();
err
})?,
None => continue,
};
let count = value.parse::<(u32)>().map_err(|_| {
let err: ApiError = ValidationErrorKind::FromDetails(
format!("Invalid integer value: {}", value),
RequestErrorLocation::Header,
Some((*header).to_owned()),
).into();
err
})?;
if count > *limit {
return Err(ValidationErrorKind::FromDetails(
"size-limit-exceeded".to_owned(),
RequestErrorLocation::Header,
None,
).into())
}
}
if params.batch.is_none() && params.commit.is_none() {
// No batch options requested
return Ok(None);
} else if params.batch.is_none() {
// commit w/ no batch ID is an error
let err: DbError = DbErrorKind::BatchNotFound.into();
return Err(err.into());
}
params.validate().map_err(|e| {
ValidationErrorKind::FromValidationErrors(e, RequestErrorLocation::QueryString)
})?;
let id = match params.batch {
None => None,
Some(ref batch) if batch == "" || TRUE_REGEX.is_match(batch) => None,
Some(ref batch) => {
let bytes = base64::decode(batch).unwrap_or(batch.as_bytes().to_vec());
let decoded = std::str::from_utf8(&bytes).unwrap_or(batch);
Some(decoded.parse::<i64>().map_err(|_| {
ValidationErrorKind::FromDetails(
format!(r#"Invalid batch ID: "{}""#, batch),
RequestErrorLocation::QueryString,
Some("batch".to_owned()),
)
})?)
}
};
Ok(Some(BatchRequest {
id,
commit: params.commit.is_some(),
}))
}
}
/// PreCondition Header
///
/// It's valid to include a X-If-Modified-Since or X-If-Unmodified-Since header but not
@ -844,6 +953,17 @@ fn validate_qs_ids(ids: &Vec<String>) -> Result<(), ValidationError> {
Ok(())
}
/// Verifies the batch commit field is valid
///
/// The only accepted value is "true" (case-insensitive); anything else is
/// rejected as a query-string validation error. Takes `&str` rather than
/// `&String` (the `&String` argument from the validator derive coerces).
fn validate_qs_commit(commit: &str) -> Result<(), ValidationError> {
    if !TRUE_REGEX.is_match(commit) {
        return Err(request_error(
            r#"commit parameter must be "true" to apply batches"#,
            RequestErrorLocation::QueryString,
        ));
    }
    Ok(())
}
/// Verifies the BSO sortindex is in the valid range
fn validate_body_bso_sortindex(sort: i32) -> Result<(), ValidationError> {
if BSO_MIN_SORTINDEX_VALUE <= sort && sort <= BSO_MAX_SORTINDEX_VALUE {
@ -922,8 +1042,8 @@ mod tests {
use std::sync::Arc;
use actix_web::test::TestRequest;
use actix_web::HttpResponse;
use actix_web::{http::Method, Binary, Body};
use actix_web::{Error, HttpResponse};
use base64;
use hawk::{Credentials, Key, RequestBuilder};
use hmac::{Hmac, Mac};
@ -1005,6 +1125,28 @@ mod tests {
format!("Hawk {}", request.make_header(&credentials).unwrap())
}
/// Builds a signed POST to /storage/1.5/1/storage/tabs with `qs` as the
/// query string and `body` as the JSON payload, then runs the
/// CollectionPostRequest extractor against it.
fn post_collection(qs: &str, body: &serde_json::Value) -> Result<CollectionPostRequest, Error> {
    let hawk_payload = HawkPayload::test_default();
    let state = make_state();
    let query = if qs.is_empty() {
        String::new()
    } else {
        format!("?{}", qs)
    };
    let path = format!("/storage/1.5/1/storage/tabs{}", query);
    let auth =
        create_valid_hawk_header(&hawk_payload, &state, "POST", &path, "localhost", 5000);
    let uri = format!("http://localhost:5000{}", path);
    let req = TestRequest::with_state(state)
        .header("authorization", auth)
        .header("content-type", "application/json")
        .method(Method::POST)
        .uri(&uri)
        .set_payload(body.to_string())
        .param("uid", "1")
        .param("collection", "tabs")
        .finish();
    req.extensions_mut().insert(make_db());
    CollectionPostRequest::extract(&req).wait()
}
#[test]
fn test_invalid_query_args() {
let req = TestRequest::with_state(make_state())
@ -1250,70 +1392,79 @@ mod tests {
#[test]
fn test_valid_collection_post_request() {
let payload = HawkPayload::test_default();
let state = make_state();
let header = create_valid_hawk_header(
&payload,
&state,
"POST",
"/storage/1.5/1/storage/tabs",
"localhost",
5000,
);
// Batch requests require id's on each BSO
let bso_body = json!([
{"id": "123", "payload": "xxx", "sortindex": 23},
{"id": "456", "payload": "xxxasdf", "sortindex": 23}
]);
let req = TestRequest::with_state(state)
.header("authorization", header)
.header("content-type", "application/json")
.method(Method::POST)
.uri("http://localhost:5000/storage/1.5/1/storage/tabs")
.set_payload(bso_body.to_string())
.param("uid", "1")
.param("collection", "tabs")
.finish();
req.extensions_mut().insert(make_db());
let result = CollectionPostRequest::extract(&req).wait().unwrap();
let result = post_collection("", &bso_body).unwrap();
assert_eq!(result.user_id.legacy_id, 1);
assert_eq!(&result.collection, "tabs");
assert_eq!(result.bsos.valid.len(), 2);
assert!(result.batch.is_none());
}
#[test]
fn test_invalid_collection_post_request() {
let payload = HawkPayload::test_default();
let state = make_state();
let header = create_valid_hawk_header(
&payload,
&state,
"POST",
"/storage/1.5/1/storage/tabs",
"localhost",
5000,
);
// Add extra fields, these will be invalid
let bso_body = json!([
{"id": "1", "sortindex": 23, "jump": 1},
{"id": "2", "sortindex": -99, "hop": "low"}
]);
let req = TestRequest::with_state(state)
.header("authorization", header)
.header("content-type", "application/json")
.method(Method::POST)
.uri("http://localhost:5000/storage/1.5/1/storage/tabs")
.set_payload(bso_body.to_string())
.param("uid", "1")
.param("collection", "tabs")
.finish();
req.extensions_mut().insert(make_db());
let result = CollectionPostRequest::extract(&req).wait().unwrap();
let result = post_collection("", &bso_body).unwrap();
assert_eq!(result.user_id.legacy_id, 1);
assert_eq!(&result.collection, "tabs");
assert_eq!(result.bsos.invalid.len(), 2);
}
#[test]
fn test_valid_collection_batch_post_request() {
    // A "batch" parameter with no value, or with the value "true"
    // (any case), requests creation of a new batch.
    let bso_body = json!([
        {"id": "123", "payload": "xxx", "sortindex": 23},
        {"id": "456", "payload": "xxxasdf", "sortindex": 23}
    ]);

    // "batch=True" -> new batch, no commit
    let result = post_collection("batch=True", &bso_body).unwrap();
    assert_eq!(result.user_id.legacy_id, 1);
    assert_eq!(&result.collection, "tabs");
    assert_eq!(result.bsos.valid.len(), 2);
    let batch = result.batch.unwrap();
    assert!(batch.id.is_none());
    assert!(!batch.commit);

    // bare "batch" (no value) -> also a new batch
    let batch = post_collection("batch", &bso_body)
        .unwrap()
        .batch
        .unwrap();
    assert!(batch.id.is_none());
    assert!(!batch.commit);

    // base64("12") plus commit=true -> existing batch 12, committed
    let batch = post_collection("batch=MTI%3D&commit=true", &bso_body)
        .unwrap()
        .batch
        .unwrap();
    assert_eq!(batch.id, Some(12));
    assert!(batch.commit);
}
#[test]
fn test_invalid_collection_batch_post_request() {
    // Both an unparseable batch id and "commit" without a batch id must be
    // rejected with a 400 and the standard "0" weave error body.
    // (The original duplicated the same 6-line assertion block; loop instead.)
    let bso_body = json!([
        {"id": "123", "payload": "xxx", "sortindex": 23},
        {"id": "456", "payload": "xxxasdf", "sortindex": 23}
    ]);
    for qs in &["batch=sammich", "commit=true"] {
        let result = post_collection(qs, &bso_body);
        assert!(result.is_err());
        let response: HttpResponse = result.err().unwrap().into();
        assert_eq!(response.status(), 400);
        let body = extract_body_as_str(&response);
        assert_eq!(body, "0");
    }
}
#[test]
fn test_invalid_precondition_headers() {
fn assert_invalid_header(

View File

@ -2,10 +2,10 @@
use std::collections::HashMap;
use actix_web::{http::StatusCode, FutureResponse, HttpResponse, State};
use futures::future::{self, Future};
use futures::future::{self, Either, Future};
use serde::Serialize;
use db::{params, results::Paginated};
use db::{params, results::Paginated, DbError, DbErrorKind};
use error::ApiError;
use server::ServerState;
use web::extractors::{
@ -172,6 +172,9 @@ where
}
pub fn post_collection(coll: CollectionPostRequest) -> FutureResponse<HttpResponse> {
if coll.batch.is_some() {
return post_collection_batch(coll);
}
Box::new(
coll.db
.post_bsos(params::PostBsos {
@ -188,6 +191,118 @@ pub fn post_collection(coll: CollectionPostRequest) -> FutureResponse<HttpRespon
)
}
/// Handles a collection POST that carries batch options (`?batch=` and/or
/// `?commit=true`).
///
/// Flow: (1) validate an existing batch id or create a new batch,
/// (2) append the request's valid BSOs to it, (3) if commit was requested,
/// fetch and commit the batch and report the new modified timestamp.
/// Responds 202 with the (base64'd) batch id when not committing,
/// 200 with "modified" and an X-Last-Modified header when committing.
pub fn post_collection_batch(coll: CollectionPostRequest) -> FutureResponse<HttpResponse> {
    // Bail early if we have nonsensical arguments
    let breq = match coll.batch.clone() {
        Some(breq) => breq,
        None => {
            let err: DbError = DbErrorKind::BatchNotFound.into();
            let err: ApiError = err.into();
            return Box::new(future::err(err.into()))
        },
    };

    // Resolve the batch id: either validate the client-supplied one or
    // create a fresh, empty batch to append into.
    let fut = if let Some(id) = breq.id {
        // Validate the batch before attempting a full append (for efficiency)
        Either::A(
            coll.db
                .validate_batch(params::ValidateBatch {
                    user_id: coll.user_id.clone(),
                    collection: coll.collection.clone(),
                    id,
                }).and_then(move |is_valid| {
                    if is_valid {
                        Box::new(future::ok(id))
                    } else {
                        // Unknown or expired batch id from the client
                        let err: DbError = DbErrorKind::BatchNotFound.into();
                        Box::new(future::err(err.into()))
                    }
                }),
        )
    } else {
        Either::B(coll.db.create_batch(params::CreateBatch {
            user_id: coll.user_id.clone(),
            collection: coll.collection.clone(),
            bsos: vec![],
        })
        )
    };

    // Clones taken up front: `coll` is moved into the append closure below.
    let db = coll.db.clone();
    let user_id = coll.user_id.clone();
    let collection = coll.collection.clone();

    let fut = fut
        .and_then(move |id| {
            let mut success = vec![];
            // BSOs that already failed body validation stay failed
            let mut failed = coll.bsos.invalid.clone();
            let bso_ids: Vec<_> = coll.bsos.valid.iter().map(|bso| bso.id.clone()).collect();
            coll.db
                .append_to_batch(params::AppendToBatch {
                    user_id: coll.user_id.clone(),
                    collection: coll.collection.clone(),
                    id,
                    bsos: coll.bsos.valid.into_iter().map(From::from).collect(),
                })
                .then(move |result| {
                    // Append is all-or-nothing here: on success every BSO is
                    // marked successful; on a non-conflict error they are all
                    // marked failed. Conflicts abort the whole request.
                    match result {
                        Ok(_) => success.extend(bso_ids),
                        Err(e) => {
                            // NLL: not a guard as: (E0008) "moves value into
                            // pattern guard"
                            if e.is_conflict() {
                                return future::err(e);
                            }
                            failed.extend(
                                bso_ids.into_iter().map(|id| (id, "db error".to_owned())),
                            )
                        }
                    };
                    future::ok((id, success, failed))
                })
        }).map_err(From::from);

    Box::new(fut.and_then(move |(id, success, failed)| {
        let mut resp = json!({
            "success": success,
            "failed": failed,
        });

        if !breq.commit {
            // Not committing yet: return 202 with the batch id (base64'd, as
            // clients expect) so the client can keep appending to it.
            resp["batch"] = json!(base64::encode(&id.to_string()));
            return Either::A(future::ok(HttpResponse::Accepted().json(resp)));
        }

        // Commit requested: fetch the batch, then commit it.
        let fut = db
            .get_batch(params::GetBatch {
                user_id: user_id.clone(),
                collection: collection.clone(),
                id,
            }).and_then(move |batch| {
                // TODO: validate *actual* sizes of the batch items
                // (max_total_records, max_total_bytes)
                if let Some(batch) = batch {
                    db.commit_batch(params::CommitBatch {
                        user_id: user_id.clone(),
                        collection: collection.clone(),
                        batch,
                    })
                } else {
                    // Batch vanished between append and commit
                    let err: DbError = DbErrorKind::BatchNotFound.into();
                    Box::new(future::err(err.into()))
                }
            }).map_err(From::from)
            .map(|result| {
                // Surface the commit timestamp in both the body and the header
                resp["modified"] = json!(result.modified);
                HttpResponse::build(StatusCode::OK)
                    .header("X-Last-Modified", result.modified.as_header())
                    .json(resp)
            });
        Either::B(fut)
    }))
}
pub fn delete_bso(bso_req: BsoRequest) -> FutureResponse<HttpResponse> {
Box::new(
bso_req