diff --git a/src/main.rs b/src/main.rs index 3b848aa5..37c02cf5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -47,7 +47,13 @@ async fn main() -> Result<(), Box> { // Setup and run the server let banner = settings.banner(); - let server = server::Server::with_settings(settings).await.unwrap(); + let server = if settings.disable_syncstorage { + server::Server::tokenserver_only_with_settings(settings) + .await + .unwrap() + } else { + server::Server::with_settings(settings).await.unwrap() + }; info!("Server running on {}", banner); server.await?; info!("Server closing"); diff --git a/src/server/mod.rs b/src/server/mod.rs index e338ac53..8f3e72f4 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -13,7 +13,7 @@ use tokio::sync::RwLock; use crate::db::{pool_from_settings, spawn_pool_periodic_reporter, DbPool}; use crate::error::ApiError; use crate::server::metrics::Metrics; -use crate::settings::{Deadman, Secrets, ServerLimits, Settings}; +use crate::settings::{Deadman, ServerLimits, Settings}; use crate::tokenserver; use crate::web::{handlers, middleware}; @@ -38,14 +38,6 @@ pub struct ServerState { /// limits rendered as JSON pub limits_json: String, - /// Secrets used during Hawk authentication. - pub secrets: Arc, - - // XXX: This is only any Option temporarily. Once Tokenserver is rolled out to production, - // it will always be enabled, and syncstorage will always have state associated with - // Tokenserver. - pub tokenserver_state: Option, - /// Metric reporting pub metrics: Box, @@ -70,9 +62,11 @@ pub struct Server; #[macro_export] macro_rules! build_app { - ($state: expr, $limits: expr) => { + ($syncstorage_state: expr, $tokenserver_state: expr, $secrets: expr, $limits: expr) => { App::new() - .data($state) + .data($syncstorage_state) + .data($tokenserver_state) + .data($secrets) // Middleware is applied LIFO // These will wrap all outbound responses with matching status codes. .wrap(ErrorHandlers::new().handler(StatusCode::NOT_FOUND, ApiError::render_404)) @@ -171,16 +165,65 @@ macro_rules! build_app { }; } +#[macro_export] +macro_rules! build_app_without_syncstorage { + ($state: expr, $secrets: expr) => { + App::new() + .data($state) + .data($secrets) + // Middleware is applied LIFO + // These will wrap all outbound responses with matching status codes. + .wrap(ErrorHandlers::new().handler(StatusCode::NOT_FOUND, ApiError::render_404)) + // These are our wrappers + .wrap(middleware::sentry::SentryWrapper::default()) + .wrap(middleware::rejectua::RejectUA::default()) + // Followed by the "official middleware" so they run first. + // actix is getting increasingly tighter about CORS headers. Our server is + // not a huge risk but does deliver XHR JSON content. + // For now, let's be permissive and use NGINX (the wrapping server) + // for finer grained specification. + .wrap(Cors::permissive()) + .service( + web::resource("/1.0/{application}/{version}") + .route(web::get().to(tokenserver::handlers::get_tokenserver_result)), + ) + // Dockerflow + // Remember to update .::web::middleware::DOCKER_FLOW_ENDPOINTS + // when applying changes to endpoint names. + .service( + web::resource("/__heartbeat__") + .route(web::get().to(tokenserver::handlers::heartbeat)), + ) + .service( + web::resource("/__lbheartbeat__").route(web::get().to(|_: HttpRequest| { + // used by the load balancers, just return OK. + HttpResponse::Ok() + .content_type("application/json") + .body("{}") + })), + ) + .service( + web::resource("/__version__").route(web::get().to(|_: HttpRequest| { + // return the contents of the version.json file created by circleci + // and stored in the docker root + HttpResponse::Ok() + .content_type("application/json") + .body(include_str!("../../version.json")) + })), + ) + }; +} + impl Server { pub async fn with_settings(settings: Settings) -> Result { let metrics = metrics::metrics_from_opts(&settings)?; + let host = settings.host.clone(); + let port = settings.port; let db_pool = pool_from_settings(&settings, &Metrics::from(&metrics)).await?; let limits = Arc::new(settings.limits); let limits_json = serde_json::to_string(&*limits).expect("ServerLimits failed to serialize"); let secrets = Arc::new(settings.master_secret); - let host = settings.host.clone(); - let port = settings.port; let quota_enabled = settings.enable_quota; let actix_keep_alive = settings.actix_keep_alive; let deadman = Arc::new(RwLock::new(Deadman { @@ -198,24 +241,47 @@ impl Server { spawn_pool_periodic_reporter(Duration::from_secs(10), metrics.clone(), db_pool.clone())?; let mut server = HttpServer::new(move || { - // Setup the server state - let state = ServerState { + let syncstorage_state = ServerState { db_pool: db_pool.clone(), limits: Arc::clone(&limits), limits_json: limits_json.clone(), - secrets: Arc::clone(&secrets), - tokenserver_state: tokenserver_state.clone(), metrics: Box::new(metrics.clone()), port, quota_enabled, deadman: Arc::clone(&deadman), }; - build_app!(state, limits) + build_app!( + syncstorage_state, + tokenserver_state.clone(), + Arc::clone(&secrets), + limits + ) }); + if let Some(keep_alive) = actix_keep_alive { server = server.keep_alive(keep_alive as usize); } + + let server = server + .bind(format!("{}:{}", host, port)) + .expect("Could not get Server in Server::with_settings") + .run(); + Ok(server) + } + + pub async fn tokenserver_only_with_settings( + settings: Settings, + ) -> Result { + let host = settings.host.clone(); + let port = settings.port; + let secrets = Arc::new(settings.master_secret); + let tokenserver_state = tokenserver::ServerState::from_settings(&settings.tokenserver)?; + + let server = HttpServer::new(move || { + build_app_without_syncstorage!(Some(tokenserver_state.clone()), Arc::clone(&secrets)) + }); + let server = server .bind(format!("{}:{}", host, port)) .expect("Could not get Server in Server::with_settings") diff --git a/src/server/test.rs b/src/server/test.rs index 37b621d5..7c66340b 100644 --- a/src/server/test.rs +++ b/src/server/test.rs @@ -24,6 +24,7 @@ use crate::db::pool_from_settings; use crate::db::results::{DeleteBso, GetBso, PostBsos, PutBso}; use crate::db::util::SyncTimestamp; use crate::settings::{test_settings, Secrets, ServerLimits}; +use crate::tokenserver; use crate::web::{auth::HawkPayload, extractors::BsoBody, X_LAST_MODIFIED}; lazy_static! { @@ -68,8 +69,6 @@ async fn get_test_state(settings: &Settings) -> ServerState { .expect("Could not get db_pool in get_test_state"), limits: Arc::clone(&SERVER_LIMITS), limits_json: serde_json::to_string(&**SERVER_LIMITS).unwrap(), - secrets: Arc::clone(&SECRETS), - tokenserver_state: None, metrics: Box::new(metrics), port: settings.port, quota_enabled: settings.enable_quota, @@ -91,7 +90,13 @@ macro_rules! init_app { async { crate::logging::init_logging(false).unwrap(); let limits = Arc::new($settings.limits.clone()); - test::init_service(build_app!(get_test_state(&$settings).await, limits)).await + test::init_service(build_app!( + get_test_state(&$settings).await, + None::, + Arc::clone(&SECRETS), + limits + )) + .await } }; } @@ -207,7 +212,13 @@ where { let settings = get_test_settings(); let limits = Arc::new(settings.limits.clone()); - let mut app = test::init_service(build_app!(get_test_state(&settings).await, limits)).await; + let mut app = test::init_service(build_app!( + get_test_state(&settings).await, + None::, + Arc::clone(&SECRETS), + limits + )) + .await; let req = create_request(method, path, None, None).to_request(); let sresponse = match app.call(req).await { @@ -241,7 +252,13 @@ async fn test_endpoint_with_body( ) -> Bytes { let settings = get_test_settings(); let limits = Arc::new(settings.limits.clone()); - let mut app = test::init_service(build_app!(get_test_state(&settings).await, limits)).await; + let mut app = test::init_service(build_app!( + get_test_state(&settings).await, + None::, + Arc::clone(&SECRETS), + limits + )) + .await; let req = create_request(method, path, None, Some(body)).to_request(); let sresponse = app .call(req) diff --git a/src/settings.rs b/src/settings.rs index 95ef65d4..4006085c 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -80,6 +80,10 @@ pub struct Settings { pub spanner_emulator_host: Option, + /// Disable all of the endpoints related to syncstorage. To be used when running Tokenserver + /// in isolation. + pub disable_syncstorage: bool, + /// Settings specific to Tokenserver pub tokenserver: TokenserverSettings, } @@ -108,6 +112,7 @@ impl Default for Settings { enable_quota: false, enforce_quota: false, spanner_emulator_host: None, + disable_syncstorage: false, tokenserver: TokenserverSettings::default(), } } @@ -160,6 +165,7 @@ impl Settings { s.set_default("statsd_label", "syncstorage")?; s.set_default("enable_quota", false)?; s.set_default("enforce_quota", false)?; + s.set_default("disable_syncstorage", false)?; // Set Tokenserver defaults s.set_default( diff --git a/src/tokenserver/db/mock.rs b/src/tokenserver/db/mock.rs index dde2035e..bbbe27df 100644 --- a/src/tokenserver/db/mock.rs +++ b/src/tokenserver/db/mock.rs @@ -53,6 +53,10 @@ impl Db for MockDb { Box::pin(future::ok(())) } + fn check(&self) -> DbFuture<'_, results::Check> { + Box::pin(future::ok(true)) + } + #[cfg(test)] fn set_user_created_at( &self, diff --git a/src/tokenserver/db/models.rs b/src/tokenserver/db/models.rs index 9fa5a0ee..0bd1b6e8 100644 --- a/src/tokenserver/db/models.rs +++ b/src/tokenserver/db/models.rs @@ -225,6 +225,12 @@ impl TokenserverDb { .map_err(Into::into) } + fn check_sync(&self) -> DbResult { + // has the database been up for more than 0 seconds? + let result = diesel::sql_query("SHOW STATUS LIKE \"Uptime\"").execute(&self.inner.conn)?; + Ok(result as u64 > 0) + } + fn get_timestamp_in_milliseconds() -> i64 { SystemTime::now() .duration_since(UNIX_EPOCH) @@ -310,6 +316,11 @@ impl Db for TokenserverDb { sync_db_method!(post_user, post_user_sync, PostUser); sync_db_method!(put_user, put_user_sync, PutUser); + fn check(&self) -> DbFuture<'_, results::Check> { + let db = self.clone(); + Box::pin(block(move || db.check_sync().map_err(Into::into)).map_err(Into::into)) + } + #[cfg(test)] sync_db_method!( set_user_created_at, @@ -336,6 +347,8 @@ pub trait Db { fn put_user(&self, params: params::PutUser) -> DbFuture<'_, results::PutUser>; + fn check(&self) -> DbFuture<'_, results::Check>; + #[cfg(test)] fn set_user_created_at( &self, diff --git a/src/tokenserver/db/results.rs b/src/tokenserver/db/results.rs index dc97231f..7f3aa7fd 100644 --- a/src/tokenserver/db/results.rs +++ b/src/tokenserver/db/results.rs @@ -72,3 +72,5 @@ pub type SetUserCreatedAt = (); #[cfg(test)] pub type GetUsers = Vec; + +pub type Check = bool; diff --git a/src/tokenserver/extractors.rs b/src/tokenserver/extractors.rs index 31848135..c9ad5553 100644 --- a/src/tokenserver/extractors.rs +++ b/src/tokenserver/extractors.rs @@ -3,6 +3,8 @@ //! Handles ensuring the header's, body, and query parameters are correct, extraction to //! relevant types, and failing correctly with the appropriate errors if issues arise. +use std::sync::Arc; + use actix_web::{ dev::Payload, web::{Data, Query}, @@ -17,7 +19,8 @@ use sha2::Sha256; use super::db::{self, models::Db, params, results}; use super::error::TokenserverError; use super::support::TokenData; -use crate::server::ServerState; +use super::ServerState; +use crate::settings::Secrets; const DEFAULT_TOKEN_DURATION: u64 = 5 * 60; @@ -46,15 +49,9 @@ impl FromRequest for TokenserverRequest { Box::pin(async move { let token_data = TokenData::extract(&req).await?; - let state = get_server_state(&req)?; - let tokenserver_state = state.tokenserver_state.as_ref().unwrap(); - let fxa_metrics_hash_secret = &tokenserver_state.fxa_metrics_hash_secret.as_bytes(); - let shared_secret = - String::from_utf8(state.secrets.master_secret.clone()).map_err(|_| { - error!("⚠️ Failed to read master secret"); - - TokenserverError::internal_error() - })?; + let state = get_server_state(&req)?.as_ref().as_ref().unwrap(); + let shared_secret = get_secret(&req)?; + let fxa_metrics_hash_secret = &state.fxa_metrics_hash_secret.as_bytes(); let key_id = KeyId::extract(&req).await?; let fxa_uid = token_data.user; let hashed_fxa_uid = { @@ -90,12 +87,12 @@ impl FromRequest for TokenserverRequest { } }; let user = { - let db = tokenserver_state.db_pool.get().map_err(|_| { + let db = state.db_pool.get().map_err(|_| { error!("⚠️ Could not acquire database connection"); TokenserverError::internal_error() })?; - let email = format!("{}@{}", fxa_uid, tokenserver_state.fxa_email_domain); + let email = format!("{}@{}", fxa_uid, state.fxa_email_domain); db.get_user(params::GetUser { email, service_id }).await? }; @@ -146,9 +143,8 @@ impl FromRequest for Box { let req = req.clone(); Box::pin(async move { - let state = get_server_state(&req)?; - let tokenserver_state = state.tokenserver_state.as_ref().unwrap(); - let db = tokenserver_state.db_pool.get().map_err(|_| { + let state = get_server_state(&req)?.as_ref().as_ref().unwrap(); + let db = state.db_pool.get().map_err(|_| { error!("⚠️ Could not acquire database connection"); TokenserverError::internal_error() @@ -186,12 +182,12 @@ impl FromRequest for TokenData { let auth = BearerAuth::extract(&req) .await .map_err(|_| TokenserverError::invalid_credentials("Unsupported"))?; - let state = get_server_state(&req)?; - // XXX: tokenserver_state will no longer be an Option once the Tokenserver + // XXX: The Tokenserver state will no longer be an Option once the Tokenserver // code is rolled out, so we will eventually be able to remove this unwrap(). - let tokenserver_state = state.tokenserver_state.as_ref().unwrap(); - tokenserver_state + let state = get_server_state(&req)?.as_ref().as_ref().unwrap(); + + state .oauth_verifier .verify_token(auth.token()) .map_err(Into::into) @@ -258,14 +254,28 @@ impl FromRequest for KeyId { } } -fn get_server_state(req: &HttpRequest) -> Result<&Data, Error> { - req.app_data::>().ok_or_else(|| { +fn get_server_state(req: &HttpRequest) -> Result<&Data>, Error> { + req.app_data::>>().ok_or_else(|| { error!("⚠️ Could not load the app state"); TokenserverError::internal_error().into() }) } +fn get_secret(req: &HttpRequest) -> Result { + let secrets = req.app_data::>>().ok_or_else(|| { + error!("⚠️ Could not load the app secrets"); + + Error::from(TokenserverError::internal_error()) + })?; + + String::from_utf8(secrets.master_secret.clone()).map_err(|_| { + error!("⚠️ Failed to read master secret"); + + TokenserverError::internal_error().into() + }) +} + fn fxa_metrics_hash(fxa_uid: &str, hmac_key: &[u8]) -> String { let mut mac = Hmac::::new_from_slice(hmac_key).expect("HMAC has no key size limit"); mac.update(fxa_uid.as_bytes()); @@ -295,13 +305,10 @@ mod tests { use futures::executor::block_on; use lazy_static::lazy_static; use serde_json; - use tokio::sync::RwLock; - use crate::db::mock::MockDbPool; - use crate::server::{metrics, ServerState}; - use crate::settings::{Deadman, Secrets, ServerLimits, Settings}; + use crate::settings::{Secrets, ServerLimits}; use crate::tokenserver::{ - self, db::mock::MockDbPool as MockTokenserverPool, MockOAuthVerifier, + db::mock::MockDbPool as MockTokenserverPool, MockOAuthVerifier, ServerState, }; use std::sync::Arc; @@ -330,7 +337,8 @@ mod tests { let state = make_state(verifier); let req = TestRequest::default() - .data(state) + .data(Some(state)) + .data(Arc::clone(&SECRETS)) .header("authorization", "Bearer fake_token") .header("accept", "application/json,text/plain:q=0.5") .header("x-keyid", "0000000001234-YWFh") @@ -380,7 +388,8 @@ mod tests { let state = make_state(verifier); let request = TestRequest::default() - .data(state) + .data(Some(state)) + .data(Arc::clone(&SECRETS)) .header("authorization", "Bearer fake_token") .header("accept", "application/json,text/plain:q=0.5") .header("x-keyid", "0000000001234-YWFh") @@ -421,7 +430,8 @@ mod tests { }; TestRequest::default() - .data(make_state(verifier)) + .data(Some(make_state(verifier))) + .data(Arc::clone(&SECRETS)) .header("authorization", "Bearer fake_token") .header("accept", "application/json,text/plain:q=0.5") .header("x-keyid", "0000000001234-YWFh") @@ -530,7 +540,7 @@ mod tests { }; TestRequest::default() - .data(make_state(verifier)) + .data(Some(make_state(verifier))) .header("authorization", "Bearer fake_token") .header("accept", "application/json,text/plain:q=0.5") .param("application", "sync") @@ -663,24 +673,11 @@ mod tests { } fn make_state(verifier: MockOAuthVerifier) -> ServerState { - let settings = Settings::default(); - let tokenserver_state = tokenserver::ServerState { + ServerState { fxa_email_domain: "test.com".to_owned(), fxa_metrics_hash_secret: "".to_owned(), oauth_verifier: Box::new(verifier), db_pool: Box::new(MockTokenserverPool::new()), - }; - - ServerState { - db_pool: Box::new(MockDbPool::new()), - limits: Arc::clone(&SERVER_LIMITS), - limits_json: serde_json::to_string(&**SERVER_LIMITS).unwrap(), - secrets: Arc::clone(&SECRETS), - tokenserver_state: Some(tokenserver_state), - port: 8000, - metrics: Box::new(metrics::metrics_from_opts(&settings).unwrap()), - quota_enabled: settings.enable_quota, - deadman: Arc::new(RwLock::new(Deadman::default())), } } } diff --git a/src/tokenserver/handlers.rs b/src/tokenserver/handlers.rs index 4e2d3a37..453086d1 100644 --- a/src/tokenserver/handlers.rs +++ b/src/tokenserver/handlers.rs @@ -1,4 +1,7 @@ -use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use std::{ + sync::Arc, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; use actix_web::http::StatusCode; use actix_web::web::Data; @@ -6,15 +9,18 @@ use actix_web::Error; use actix_web::{HttpRequest, HttpResponse}; use hmac::{Hmac, Mac, NewMac}; use serde::Serialize; +use serde_json::Value; use sha2::Sha256; +use std::collections::HashMap; -use super::db::{self, params::GetUser}; +use super::db::{self, models::Db, params::GetUser}; use super::extractors::TokenserverRequest; use super::support::Tokenlib; -use crate::tokenserver::support::MakeTokenPlaintext; +use super::ServerState; use crate::{ error::{ApiError, ApiErrorKind}, - server::ServerState, + settings::Secrets, + tokenserver::support::MakeTokenPlaintext, }; #[derive(Debug, Serialize)] @@ -32,18 +38,17 @@ pub async fn get_tokenserver_result( request: HttpRequest, ) -> Result { let state = request - .app_data::>() - .ok_or_else(|| internal_error("Could not load the app state"))?; - let tokenserver_state = state.tokenserver_state.as_ref().unwrap(); + .app_data::>>() + .ok_or_else(|| internal_error("Could not load the app state"))? + .as_ref() + .as_ref() + .unwrap(); let db = { - let db_pool = tokenserver_state.db_pool.clone(); + let db_pool = state.db_pool.clone(); db_pool.get().map_err(ApiError::from)? }; - let user_email = format!( - "{}@{}", - tokenserver_request.fxa_uid, tokenserver_state.fxa_email_domain - ); + let user_email = format!("{}@{}", tokenserver_request.fxa_uid, state.fxa_email_domain); let tokenserver_user = { let params = GetUser { email: user_email.clone(), @@ -53,10 +58,7 @@ pub async fn get_tokenserver_result( db.get_user(params).await? }; - let fxa_metrics_hash_secret = tokenserver_state - .fxa_metrics_hash_secret - .clone() - .into_bytes(); + let fxa_metrics_hash_secret = state.fxa_metrics_hash_secret.clone().into_bytes(); let hashed_fxa_uid_full = fxa_metrics_hash(&tokenserver_request.fxa_uid, &fxa_metrics_hash_secret)?; @@ -80,8 +82,14 @@ pub async fn get_tokenserver_result( }; let (token, derived_secret) = { - let shared_secret = String::from_utf8(state.secrets.master_secret.clone()) - .map_err(|_| internal_error("Failed to read master secret"))?; + let shared_secret = String::from_utf8( + request + .app_data::>>() + .ok_or_else(|| internal_error("Could not load the app secrets"))? + .master_secret + .clone(), + ) + .map_err(|_| internal_error("Failed to read master secret"))?; let make_token_plaintext = { let expires = { @@ -142,3 +150,34 @@ fn internal_error(message: &str) -> HttpResponse { HttpResponse::InternalServerError().body("") } + +pub async fn heartbeat(db: Box) -> Result { + let mut checklist = HashMap::new(); + checklist.insert( + "version".to_owned(), + Value::String(env!("CARGO_PKG_VERSION").to_owned()), + ); + + match db.check().await { + Ok(result) => { + if result { + checklist.insert("database".to_owned(), Value::from("Ok")); + } else { + checklist.insert("database".to_owned(), Value::from("Err")); + checklist.insert( + "database_msg".to_owned(), + Value::from("check failed without error"), + ); + }; + let status = if result { "Ok" } else { "Err" }; + checklist.insert("status".to_owned(), Value::from(status)); + Ok(HttpResponse::Ok().json(checklist)) + } + Err(e) => { + error!("Heartbeat error: {:?}", e); + checklist.insert("status".to_owned(), Value::from("Err")); + checklist.insert("database".to_owned(), Value::from("Unknown")); + Ok(HttpResponse::ServiceUnavailable().json(checklist)) + } + } +} diff --git a/src/web/extractors.rs b/src/web/extractors.rs index df5cffbd..7f36d07f 100644 --- a/src/web/extractors.rs +++ b/src/web/extractors.rs @@ -2,7 +2,9 @@ //! //! Handles ensuring the header's, body, and query parameters are correct, extraction to //! relevant types, and failing correctly with the appropriate errors if issues arise. -use std::{self, collections::HashMap, collections::HashSet, num::ParseIntError, str::FromStr}; +use std::{ + self, collections::HashMap, collections::HashSet, num::ParseIntError, str::FromStr, sync::Arc, +}; use actix_web::{ dev::{ConnectionInfo, Extensions, Payload, RequestHead}, @@ -1041,7 +1043,7 @@ impl HawkIdentifier { method: &str, uri: &Uri, ci: &ConnectionInfo, - state: &ServerState, + secrets: &Secrets, ) -> Result where T: HttpMessage, @@ -1056,7 +1058,7 @@ impl HawkIdentifier { .ok_or_else(|| -> ApiError { HawkErrorKind::MissingHeader.into() })? .to_str() .map_err(|e| -> ApiError { HawkErrorKind::Header(e).into() })?; - let identifier = Self::generate(&state.secrets, method, auth_header, ci, uri)?; + let identifier = Self::generate(secrets, method, auth_header, ci, uri)?; msg.extensions_mut().insert(identifier.clone()); Ok(identifier) } @@ -1099,14 +1101,14 @@ impl FromRequest for HawkIdentifier { let req = req.clone(); Box::pin(async move { - let state = match req.app_data::>() { + let secrets = match req.app_data::>>() { Some(s) => s, None => { - error!("⚠️ Could not load the app state"); + error!("⚠️ Could not load the app secrets"); return Err(ValidationErrorKind::FromDetails( "Internal error".to_owned(), RequestErrorLocation::Unknown, - Some("state".to_owned()), + Some("secrets".to_owned()), None, ) .into()); @@ -1116,7 +1118,7 @@ impl FromRequest for HawkIdentifier { let connection_info = req.connection_info().clone(); let method = req.method().as_str(); let uri = req.uri(); - Self::extrude(&req, method, uri, &connection_info, state) + Self::extrude(&req, method, uri, &connection_info, secrets) }) } } @@ -1722,8 +1724,6 @@ mod tests { db_pool: Box::new(MockDbPool::new()), limits: Arc::clone(&SERVER_LIMITS), limits_json: serde_json::to_string(&**SERVER_LIMITS).unwrap(), - secrets: Arc::clone(&SECRETS), - tokenserver_state: None, port: 8000, metrics: Box::new(metrics::metrics_from_opts(&settings).unwrap()), quota_enabled: settings.enable_quota, @@ -1737,7 +1737,7 @@ mod tests { fn create_valid_hawk_header( payload: &HawkPayload, - state: &ServerState, + secrets: &Secrets, method: &str, path: &str, host: &str, @@ -1745,7 +1745,7 @@ mod tests { ) -> String { let salt = payload.salt.clone(); let payload = serde_json::to_string(payload).unwrap(); - let mut hmac = Hmac::::new_from_slice(&state.secrets.signing_secret).unwrap(); + let mut hmac = Hmac::::new_from_slice(&secrets.signing_secret).unwrap(); hmac.update(payload.as_bytes()); let payload_hash = hmac.finalize().into_bytes(); let mut id = payload.as_bytes().to_vec(); @@ -1774,6 +1774,7 @@ mod tests { ) -> Result { let payload = HawkPayload::test_default(*USER_ID); let state = make_state(); + let secrets = Arc::clone(&SECRETS); let path = format!( "/1.5/{}/storage/tabs{}{}", *USER_ID, @@ -1782,9 +1783,10 @@ mod tests { ); let bod_str = body.to_string(); let header = - create_valid_hawk_header(&payload, &state, "POST", &path, TEST_HOST, TEST_PORT); + create_valid_hawk_header(&payload, &secrets, "POST", &path, TEST_HOST, TEST_PORT); let req = TestRequest::with_uri(&format!("http://{}:{}{}", TEST_HOST, TEST_PORT, path)) .data(state) + .data(secrets) .method(Method::POST) .header("authorization", header) .header("content-type", "application/json; charset=UTF-8") @@ -1882,10 +1884,13 @@ mod tests { fn test_valid_bso_request() { let payload = HawkPayload::test_default(*USER_ID); let state = make_state(); + let secrets = Arc::clone(&SECRETS); let uri = format!("/1.5/{}/storage/tabs/asdf", *USER_ID); - let header = create_valid_hawk_header(&payload, &state, "GET", &uri, TEST_HOST, TEST_PORT); + let header = + create_valid_hawk_header(&payload, &secrets, "GET", &uri, TEST_HOST, TEST_PORT); let req = TestRequest::with_uri(&uri) .data(state) + .data(secrets) .header("authorization", header) .method(Method::GET) .param("uid", &USER_ID_STR) @@ -1904,10 +1909,13 @@ mod tests { fn test_invalid_bso_request() { let payload = HawkPayload::test_default(*USER_ID); let state = make_state(); + let secrets = Arc::clone(&SECRETS); let uri = format!("/1.5/{}/storage/tabs/{}", *USER_ID, INVALID_BSO_NAME); - let header = create_valid_hawk_header(&payload, &state, "GET", &uri, TEST_HOST, TEST_PORT); + let header = + create_valid_hawk_header(&payload, &secrets, "GET", &uri, TEST_HOST, TEST_PORT); let req = TestRequest::with_uri(&uri) .data(state) + .data(secrets) .header("authorization", header) .method(Method::GET) // `param` sets the value that would be extracted from the tokenized URI, as if the router did it. @@ -1938,13 +1946,16 @@ mod tests { fn test_valid_bso_post_body() { let payload = HawkPayload::test_default(*USER_ID); let state = make_state(); + let secrets = Arc::clone(&SECRETS); let uri = format!("/1.5/{}/storage/tabs/asdf", *USER_ID); - let header = create_valid_hawk_header(&payload, &state, "POST", &uri, TEST_HOST, TEST_PORT); + let header = + create_valid_hawk_header(&payload, &secrets, "POST", &uri, TEST_HOST, TEST_PORT); let bso_body = json!({ "id": "128", "payload": "x" }); let req = TestRequest::with_uri(&uri) .data(state) + .data(secrets) .header("authorization", header) .header("content-type", "application/json") .method(Method::POST) @@ -1968,13 +1979,16 @@ mod tests { fn test_invalid_bso_post_body() { let payload = HawkPayload::test_default(*USER_ID); let state = make_state(); + let secrets = Arc::clone(&SECRETS); let uri = format!("/1.5/{}/storage/tabs/asdf", *USER_ID); - let header = create_valid_hawk_header(&payload, &state, "POST", &uri, TEST_HOST, TEST_PORT); + let header = + create_valid_hawk_header(&payload, &secrets, "POST", &uri, TEST_HOST, TEST_PORT); let bso_body = json!({ "payload": "xxx", "sortindex": -9_999_999_999_i64, }); let req = TestRequest::with_uri(&uri) .data(state) + .data(secrets) .header("authorization", header) .header("content-type", "application/json") .method(Method::POST) @@ -2006,10 +2020,13 @@ mod tests { fn test_valid_collection_request() { let payload = HawkPayload::test_default(*USER_ID); let state = make_state(); + let secrets = Arc::clone(&SECRETS); let uri = format!("/1.5/{}/storage/tabs", *USER_ID); - let header = create_valid_hawk_header(&payload, &state, "GET", &uri, TEST_HOST, TEST_PORT); + let header = + create_valid_hawk_header(&payload, &secrets, "GET", &uri, TEST_HOST, TEST_PORT); let req = TestRequest::with_uri(&uri) .data(state) + .data(secrets) .header("authorization", header) .header("accept", "application/json,text/plain:q=0.5") .method(Method::GET) @@ -2028,14 +2045,17 @@ mod tests { let payload = HawkPayload::test_default(*USER_ID); let altered_bso = format!("\"{{{}}}\"", *USER_ID); let state = make_state(); + let secrets = Arc::clone(&SECRETS); let uri = format!( "/1.5/{}/storage/tabs/{}", *USER_ID, urlencoding::encode(&altered_bso) ); - let header = create_valid_hawk_header(&payload, &state, "GET", &uri, TEST_HOST, TEST_PORT); + let header = + create_valid_hawk_header(&payload, &secrets, "GET", &uri, TEST_HOST, TEST_PORT); let req = TestRequest::with_uri(&uri) .data(state) + .data(secrets) .header("authorization", header) .header("accept", "application/json,text/plain:q=0.5") .method(Method::GET) @@ -2052,13 +2072,15 @@ mod tests { fn test_invalid_collection_request() { let hawk_payload = HawkPayload::test_default(*USER_ID); let state = make_state(); + let secrets = Arc::clone(&SECRETS); let uri = format!("/1.5/{}/storage/{}", *USER_ID, INVALID_COLLECTION_NAME); let header = - create_valid_hawk_header(&hawk_payload, &state, "GET", &uri, TEST_HOST, TEST_PORT); + create_valid_hawk_header(&hawk_payload, &secrets, "GET", &uri, TEST_HOST, TEST_PORT); let req = TestRequest::with_uri(&uri) .header("authorization", header) .method(Method::GET) .data(state) + .data(secrets) .param("uid", &USER_ID_STR) .param("collection", INVALID_COLLECTION_NAME) .to_http_request(); @@ -2240,13 +2262,15 @@ mod tests { fn valid_header_with_valid_path() { let hawk_payload = HawkPayload::test_default(*USER_ID); let state = make_state(); + let secrets = Arc::clone(&SECRETS); let uri = format!("/1.5/{}/storage/col2", *USER_ID); let header = - create_valid_hawk_header(&hawk_payload, &state, "GET", &uri, TEST_HOST, TEST_PORT); + create_valid_hawk_header(&hawk_payload, &secrets, "GET", &uri, TEST_HOST, TEST_PORT); let req = TestRequest::with_uri(&uri) .header("authorization", header) .method(Method::GET) .data(state) + .data(secrets) .param("uid", &USER_ID_STR) .to_http_request(); let mut payload = Payload::None; @@ -2261,11 +2285,13 @@ mod tests { let hawk_payload = HawkPayload::test_default(*USER_ID); let mismatch_uid = "5"; let state = make_state(); + let secrets = Arc::clone(&SECRETS); let uri = format!("/1.5/{}/storage/col2", mismatch_uid); let header = - create_valid_hawk_header(&hawk_payload, &state, "GET", &uri, TEST_HOST, TEST_PORT); + create_valid_hawk_header(&hawk_payload, &secrets, "GET", &uri, TEST_HOST, TEST_PORT); let req = TestRequest::with_uri(&uri) .data(state) + .data(secrets) .header("authorization", header) .method(Method::GET) .param("uid", mismatch_uid) diff --git a/src/web/handlers.rs b/src/web/handlers.rs index 4b26f9b3..488a2835 100644 --- a/src/web/handlers.rs +++ b/src/web/handlers.rs @@ -6,7 +6,7 @@ use actix_web::{ dev::HttpResponseBuilder, http::StatusCode, web::Data, Error, HttpRequest, HttpResponse, }; use serde::Serialize; -use serde_json::{json, Value}; +use serde_json::{json, Map, Value}; use crate::{ db::{ @@ -18,6 +18,7 @@ use crate::{ }, error::{ApiError, ApiErrorKind, ApiResult}, server::ServerState, + tokenserver, web::{ extractors::{ BsoPutRequest, BsoRequest, CollectionPostRequest, CollectionRequest, HeartbeatRequest, @@ -544,7 +545,7 @@ pub fn get_configuration(state: Data) -> HttpResponse { /** Returns a status message indicating the state of the current server * */ -pub async fn heartbeat(hb: HeartbeatRequest) -> Result { +pub async fn heartbeat(hb: HeartbeatRequest, req: HttpRequest) -> Result { let mut checklist = HashMap::new(); checklist.insert( "version".to_owned(), @@ -554,6 +555,44 @@ pub async fn heartbeat(hb: HeartbeatRequest) -> Result { checklist.insert("quota".to_owned(), serde_json::to_value(hb.quota)?); + let tokenserver_state = match req.app_data::>>() { + Some(s) => s, + None => { + error!("⚠️ Could not load the app state"); + return Ok(HttpResponse::InternalServerError().body("")); + } + }; + + let mut tokenserver_service_unavailable = false; + if let Some(tokenserver_state) = tokenserver_state.as_ref() { + let db = tokenserver_state.db_pool.get().map_err(ApiError::from)?; + let mut tokenserver_checklist = Map::new(); + + match db.check().await { + Ok(result) => { + if result { + tokenserver_checklist.insert("database".to_owned(), Value::from("Ok")); + } else { + tokenserver_checklist.insert("database".to_owned(), Value::from("Err")); + tokenserver_checklist.insert( + "database_msg".to_owned(), + Value::from("check failed without error"), + ); + }; + let status = if result { "Ok" } else { "Err" }; + tokenserver_checklist.insert("status".to_owned(), Value::from(status)); + } + Err(e) => { + error!("Heartbeat error: {:?}", e); + tokenserver_checklist.insert("status".to_owned(), Value::from("Err")); + tokenserver_checklist.insert("database".to_owned(), Value::from("Unknown")); + tokenserver_service_unavailable = true; + } + } + + checklist.insert("tokenserver".to_owned(), Value::from(tokenserver_checklist)); + } + match db.check().await { Ok(result) => { if result { @@ -567,7 +606,12 @@ pub async fn heartbeat(hb: HeartbeatRequest) -> Result { }; let status = if result { "Ok" } else { "Err" }; checklist.insert("status".to_owned(), Value::from(status)); - Ok(HttpResponse::Ok().json(checklist)) + + if tokenserver_service_unavailable { + Ok(HttpResponse::ServiceUnavailable().json(checklist)) + } else { + Ok(HttpResponse::Ok().json(checklist)) + } } Err(e) => { error!("Heartbeat error: {:?}", e); diff --git a/src/web/middleware/mod.rs b/src/web/middleware/mod.rs index ba330710..80a44d19 100644 --- a/src/web/middleware/mod.rs +++ b/src/web/middleware/mod.rs @@ -6,11 +6,13 @@ pub mod weave; // // Matches the [Sync Storage middleware](https://github.com/mozilla-services/server-syncstorage/blob/master/syncstorage/tweens.py) (tweens). +use std::sync::Arc; + use actix_web::{dev::ServiceRequest, Error, HttpRequest}; use crate::db::util::SyncTimestamp; use crate::error::{ApiError, ApiErrorKind}; -use crate::server::ServerState; +use crate::settings::Secrets; use crate::web::{extractors::HawkIdentifier, DOCKER_FLOW_ENDPOINTS}; use actix_web::web::Data; @@ -30,10 +32,12 @@ impl SyncServerRequest for ServiceRequest { // NOTE: `connection_info()` gets a mutable reference lock on `extensions()`, so // it must be cloned let ci = &self.connection_info().clone(); - let state = &self.app_data::().ok_or_else(|| -> ApiError { - ApiErrorKind::Internal("No app_data ServerState".to_owned()).into() - })?; - HawkIdentifier::extrude(self, method.as_str(), self.uri(), ci, state) + let secrets = &self + .app_data::>>() + .ok_or_else(|| -> ApiError { + ApiErrorKind::Internal("No app_data Secrets".to_owned()).into() + })?; + HawkIdentifier::extrude(self, method.as_str(), self.uri(), ci, secrets) } } @@ -46,11 +50,11 @@ impl SyncServerRequest for HttpRequest { // NOTE: `connection_info()` gets a mutable reference lock on `extensions()`, so // it must be cloned let ci = &self.connection_info().clone(); - let state = &self - .app_data::>() + let secrets = &self + .app_data::>>() .ok_or_else(|| -> ApiError { - ApiErrorKind::Internal("No app_data ServerState".to_owned()).into() + ApiErrorKind::Internal("No app_data Secrets".to_owned()).into() })?; - HawkIdentifier::extrude(self, method.as_str(), self.uri(), ci, state) + HawkIdentifier::extrude(self, method.as_str(), self.uri(), ci, secrets) } }