diff --git a/src/settings.rs b/src/settings.rs index aec26503..1fae3bf5 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -192,6 +192,7 @@ impl Settings { s.set_default("tokenserver.fxa_email_domain", "test.com")?; s.set_default("tokenserver.fxa_metrics_hash_secret", "secret")?; s.set_default("tokenserver.test_mode_enabled", false)?; + s.set_default("tokenserver.node_type", "spanner")?; // Set Cors defaults s.set_default( diff --git a/src/tokenserver/extractors.rs b/src/tokenserver/extractors.rs index 7da2bed3..3a762491 100644 --- a/src/tokenserver/extractors.rs +++ b/src/tokenserver/extractors.rs @@ -22,6 +22,7 @@ use sha2::Sha256; use super::db::{self, models::Db, params, results}; use super::error::{ErrorLocation, TokenserverError}; use super::support::TokenData; +use super::NodeType; use super::ServerState; use crate::settings::Secrets; @@ -45,6 +46,7 @@ pub struct TokenserverRequest { pub hashed_device_id: String, pub service_id: i32, pub duration: u64, + pub node_type: NodeType, } impl TokenserverRequest { @@ -236,6 +238,7 @@ impl FromRequest for TokenserverRequest { hashed_device_id, service_id, duration: duration.unwrap_or(DEFAULT_TOKEN_DURATION), + node_type: state.node_type, }; tokenserver_request.validate()?; @@ -502,6 +505,7 @@ mod tests { hashed_device_id: "3a41cccbdd666ebc4199f1f9d1249d44".to_owned(), service_id: db::SYNC_1_5_SERVICE_ID, duration: 100, + node_type: NodeType::default(), }; assert_eq!(result, expected_tokenserver_request); @@ -836,6 +840,7 @@ mod tests { hashed_device_id: "abcdef".to_owned(), service_id: 1, duration: DEFAULT_TOKEN_DURATION, + node_type: NodeType::default(), }; let error = tokenserver_request.validate().unwrap_err(); @@ -868,6 +873,7 @@ mod tests { hashed_device_id: "abcdef".to_owned(), service_id: 1, duration: DEFAULT_TOKEN_DURATION, + node_type: NodeType::default(), }; let error = tokenserver_request.validate().unwrap_err(); @@ -899,6 +905,7 @@ mod tests { hashed_device_id: "abcdef".to_owned(), service_id: 1, duration: DEFAULT_TOKEN_DURATION, + node_type: NodeType::default(), }; let error = tokenserver_request.validate().unwrap_err(); @@ -931,6 +938,7 @@ mod tests { hashed_device_id: "abcdef".to_owned(), service_id: 1, duration: DEFAULT_TOKEN_DURATION, + node_type: NodeType::default(), }; let error = tokenserver_request.validate().unwrap_err(); @@ -963,6 +971,7 @@ mod tests { hashed_device_id: "abcdef".to_owned(), service_id: 1, duration: DEFAULT_TOKEN_DURATION, + node_type: NodeType::default(), }; let error = tokenserver_request.validate().unwrap_err(); @@ -995,6 +1004,7 @@ mod tests { hashed_device_id: "abcdef".to_owned(), service_id: 1, duration: DEFAULT_TOKEN_DURATION, + node_type: NodeType::default(), }; let error = tokenserver_request.validate().unwrap_err(); @@ -1014,6 +1024,7 @@ mod tests { oauth_verifier: Box::new(verifier), db_pool: Box::new(MockTokenserverPool::new()), node_capacity_release_rate: None, + node_type: NodeType::default(), } } } diff --git a/src/tokenserver/handlers.rs b/src/tokenserver/handlers.rs index 274d6c63..2fc3fa60 100644 --- a/src/tokenserver/handlers.rs +++ b/src/tokenserver/handlers.rs @@ -13,6 +13,7 @@ use super::db::params::{GetNodeId, PostUser, PutUser, ReplaceUsers}; use super::error::TokenserverError; use super::extractors::TokenserverRequest; use super::support::{self, Tokenlib}; +use super::NodeType; use crate::tokenserver::support::MakeTokenPlaintext; #[derive(Debug, Serialize)] @@ -23,6 +24,8 @@ pub struct TokenserverResult { api_endpoint: String, duration: u64, hashed_fxa_uid: String, + hashalg: &'static str, + node_type: NodeType, } pub async fn get_tokenserver_result( @@ -57,6 +60,8 @@ pub async fn get_tokenserver_result( api_endpoint: format!("{:}/1.5/{:}", req.user.node, req.user.uid), duration: req.duration, hashed_fxa_uid: req.hashed_fxa_uid, + hashalg: "sha256", + node_type: req.node_type, }; Ok(HttpResponse::build(StatusCode::OK).json(result)) diff --git a/src/tokenserver/mod.rs b/src/tokenserver/mod.rs index b29e0259..daeccf50 100644 --- a/src/tokenserver/mod.rs +++ b/src/tokenserver/mod.rs @@ -8,6 +8,7 @@ pub mod support; pub use self::support::{MockOAuthVerifier, OAuthVerifier, TestModeOAuthVerifier, VerifyToken}; use db::pool::{DbPool, TokenserverPool}; +use serde::{Deserialize, Serialize}; use settings::Settings; use crate::error::ApiError; @@ -19,6 +20,7 @@ pub struct ServerState { pub fxa_metrics_hash_secret: String, pub oauth_verifier: Box, pub node_capacity_release_rate: Option, + pub node_type: NodeType, } impl ServerState { @@ -47,7 +49,22 @@ impl ServerState { oauth_verifier, db_pool: Box::new(db_pool), node_capacity_release_rate: settings.node_capacity_release_rate, + node_type: settings.node_type, }) .map_err(Into::into) } } + +#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Serialize)] +pub enum NodeType { + #[serde(rename = "mysql")] + MySql, + #[serde(rename = "spanner")] + Spanner, +} + +impl Default for NodeType { + fn default() -> Self { + Self::Spanner + } +} diff --git a/src/tokenserver/settings.rs b/src/tokenserver/settings.rs index 465edbed..ed0f238d 100644 --- a/src/tokenserver/settings.rs +++ b/src/tokenserver/settings.rs @@ -1,5 +1,7 @@ use serde::Deserialize; +use super::NodeType; + #[derive(Clone, Debug, Deserialize)] pub struct Settings { pub database_url: String, @@ -31,6 +33,9 @@ pub struct Settings { /// The rate at which capacity should be released from nodes that are at capacity. pub node_capacity_release_rate: Option, + + /// The type of the storage nodes used by this instance of Tokenserver. + pub node_type: NodeType, } impl Default for Settings { @@ -46,6 +51,7 @@ impl Default for Settings { fxa_oauth_server_url: None, test_mode_enabled: false, node_capacity_release_rate: None, + node_type: NodeType::Spanner, } } } diff --git a/tools/integration_tests/tokenserver/test_e2e.py b/tools/integration_tests/tokenserver/test_e2e.py index 721a2b50..ddff93b0 100644 --- a/tools/integration_tests/tokenserver/test_e2e.py +++ b/tools/integration_tests/tokenserver/test_e2e.py @@ -225,4 +225,4 @@ class TestE2e(TestCase, unittest.TestCase): self.assertEqual(res.json['hashalg'], 'sha256') self.assertEqual(res.json['hashed_fxa_uid'], self._fxa_metrics_hash(fxa_uid)[:32]) - self.assertEqual(res.json['node_type'], 'example') + self.assertEqual(res.json['node_type'], 'spanner')