diff --git a/syncstorage/src/settings.rs b/syncstorage/src/settings.rs index 0a4e282c..df1dbb5c 100644 --- a/syncstorage/src/settings.rs +++ b/syncstorage/src/settings.rs @@ -289,6 +289,7 @@ impl Settings { s.set_default("tokenserver.node_type", "spanner")?; s.set_default("tokenserver.statsd_label", "syncstorage.tokenserver")?; s.set_default("tokenserver.run_migrations", cfg!(test))?; + s.set_default("tokenserver.token_duration", 3600)?; // Set Cors defaults s.set_default( diff --git a/syncstorage/src/tokenserver/db/models.rs b/syncstorage/src/tokenserver/db/models.rs index e48e574d..dbd6866a 100644 --- a/syncstorage/src/tokenserver/db/models.rs +++ b/syncstorage/src/tokenserver/db/models.rs @@ -164,7 +164,7 @@ impl TokenserverDb { WHERE service = ? AND email = ? AND generation <= ? - AND COALESCE(keys_changed_at, 0) <= COALESCE(?, 0) + AND COALESCE(keys_changed_at, 0) <= COALESCE(?, keys_changed_at, 0) AND replaced_at IS NULL "#; diff --git a/syncstorage/src/tokenserver/extractors.rs b/syncstorage/src/tokenserver/extractors.rs index bf09ce65..fa10adce 100644 --- a/syncstorage/src/tokenserver/extractors.rs +++ b/syncstorage/src/tokenserver/extractors.rs @@ -31,7 +31,6 @@ lazy_static! { static ref CLIENT_STATE_REGEX: Regex = Regex::new("^[a-zA-Z0-9._-]{1,32}$").unwrap(); } -const DEFAULT_TOKEN_DURATION: u64 = 5 * 60; const SYNC_SERVICE_NAME: &str = "sync-1.5"; /// Information from the request needed to process a Tokenserver request. @@ -98,6 +97,14 @@ impl TokenserverRequest { }); } + // If the client previously reported a client state, every subsequent request must include + // one. Note that this is only relevant for BrowserID requests, since OAuth requests must + // always include a client state. + if !self.user.client_state.is_empty() && self.auth_data.client_state.is_empty() { + let error_message = "Unacceptable client-state value empty string".to_owned(); + return Err(TokenserverError::invalid_client_state(error_message)); + } + // The client state on the request must not have been used in the past. if self .user @@ -147,6 +154,26 @@ impl TokenserverRequest { }); } + // If there's no keys_changed_at on the request, there must be no value stored on the user + // record. Note that this is only relevant for BrowserID requests, since OAuth requests + // must always include a keys_changed_at header. The Python Tokenserver converts a NULL + // keys_changed_at on the user record to 0 in memory, which means that NULL + // keys_changed_ats are treated equivalently to 0 keys_changed_ats. This would allow users + // with a 0 keys_changed_at on their user record to hold off on sending a keys_changed_at + // in requests even though the value in the database is non-NULL. To be thorough, we + // handle this case here. + if auth_keys_changed_at.is_none() + && matches!(user_keys_changed_at, Some(inner) if inner != 0) + { + let context = + "No keys_changed_at sent for a user for whom we've already seen a keys_changed_at" + .to_owned(); + return Err(TokenserverError { + context, + ..TokenserverError::invalid_keys_changed_at() + }); + } + Ok(()) } } @@ -252,7 +279,7 @@ impl FromRequest for TokenserverRequest { match duration_string.parse::() { // The specified token duration should never be greater than the default // token duration set on the server. - Ok(duration) if duration <= DEFAULT_TOKEN_DURATION => Some(duration), + Ok(duration) if duration <= state.token_duration => Some(duration), _ => None, } }) @@ -265,7 +292,7 @@ impl FromRequest for TokenserverRequest { hashed_fxa_uid, hashed_device_id, service_id, - duration: duration.unwrap_or(DEFAULT_TOKEN_DURATION), + duration: duration.unwrap_or(state.token_duration), node_type: state.node_type, }; @@ -413,6 +440,15 @@ impl FromRequest for AuthData { let TokenserverMetrics(mut metrics) = TokenserverMetrics::extract(&req).await?; + // The Python Tokenserver treats zero values and null values both as being + // null, so for consistency, we need to convert a `Some(0)` value to `None` + fn convert_zero_to_none(generation_or_keys_changed_at: Option) -> Option { + match generation_or_keys_changed_at { + Some(0) => None, + _ => generation_or_keys_changed_at, + } + } + match token { Token::BrowserIdAssertion(assertion) => { let mut tags = Tags::default(); @@ -436,8 +472,8 @@ impl FromRequest for AuthData { device_id: verify_output.device_id, email: verify_output.email.clone(), fxa_uid: fxa_uid.to_owned(), - generation: verify_output.generation, - keys_changed_at: verify_output.keys_changed_at, + generation: convert_zero_to_none(verify_output.generation), + keys_changed_at: convert_zero_to_none(verify_output.keys_changed_at), }) } Token::OAuthToken(token) => { @@ -458,8 +494,8 @@ impl FromRequest for AuthData { email, device_id: None, fxa_uid, - generation: verify_output.generation, - keys_changed_at: Some(key_id.keys_changed_at), + generation: convert_zero_to_none(verify_output.generation), + keys_changed_at: convert_zero_to_none(Some(key_id.keys_changed_at)), }) } } @@ -683,6 +719,8 @@ mod tests { static ref SERVER_LIMITS: Arc = Arc::new(ServerLimits::default()); } + const TOKEN_DURATION: u64 = 3600; + #[actix_rt::test] async fn test_valid_tokenserver_request() { let fxa_uid = "test123"; @@ -1064,7 +1102,7 @@ mod tests { hashed_fxa_uid: "abcdef".to_owned(), hashed_device_id: "abcdef".to_owned(), service_id: 1, - duration: DEFAULT_TOKEN_DURATION, + duration: TOKEN_DURATION, node_type: NodeType::default(), }; @@ -1107,7 +1145,7 @@ mod tests { hashed_fxa_uid: "abcdef".to_owned(), hashed_device_id: "abcdef".to_owned(), service_id: 1, - duration: DEFAULT_TOKEN_DURATION, + duration: TOKEN_DURATION, node_type: NodeType::default(), }; @@ -1149,7 +1187,7 @@ mod tests { hashed_fxa_uid: "abcdef".to_owned(), hashed_device_id: "abcdef".to_owned(), service_id: 1, - duration: DEFAULT_TOKEN_DURATION, + duration: TOKEN_DURATION, node_type: NodeType::default(), }; @@ -1192,7 +1230,7 @@ mod tests { hashed_fxa_uid: "abcdef".to_owned(), hashed_device_id: "abcdef".to_owned(), service_id: 1, - duration: DEFAULT_TOKEN_DURATION, + duration: TOKEN_DURATION, node_type: NodeType::default(), }; @@ -1229,7 +1267,7 @@ mod tests { hashed_fxa_uid: "abcdef".to_owned(), hashed_device_id: "abcdef".to_owned(), service_id: 1, - duration: DEFAULT_TOKEN_DURATION, + duration: TOKEN_DURATION, node_type: NodeType::default(), }; @@ -1267,7 +1305,7 @@ mod tests { hashed_fxa_uid: "abcdef".to_owned(), hashed_device_id: "abcdef".to_owned(), service_id: 1, - duration: DEFAULT_TOKEN_DURATION, + duration: TOKEN_DURATION, node_type: NodeType::default(), }; @@ -1303,6 +1341,7 @@ mod tests { ) .unwrap(), ), + token_duration: TOKEN_DURATION, } } } diff --git a/syncstorage/src/tokenserver/handlers.rs b/syncstorage/src/tokenserver/handlers.rs index 8c709e75..5080db2b 100644 --- a/syncstorage/src/tokenserver/handlers.rs +++ b/syncstorage/src/tokenserver/handlers.rs @@ -1,5 +1,4 @@ use std::{ - cmp, collections::HashMap, time::{Duration, SystemTime, UNIX_EPOCH}, }; @@ -84,7 +83,12 @@ fn get_token_plaintext( })?; let client_state_b64 = base64::encode_config(&client_state, base64::URL_SAFE_NO_PAD); - format!("{:013}-{:}", updates.keys_changed_at, client_state_b64) + format!( + "{:013}-{:}", + // We fall back to using the user's generation here, which matches FxA's behavior + updates.keys_changed_at.unwrap_or(updates.generation), + client_state_b64 + ) }; let expires = { @@ -108,7 +112,8 @@ fn get_token_plaintext( } struct UserUpdates { - keys_changed_at: i64, + keys_changed_at: Option, + generation: i64, uid: i64, } @@ -116,20 +121,63 @@ async fn update_user( req: &TokenserverRequest, db: Box, ) -> Result { - // If the keys_changed_at in the request is larger than that stored on the user record, - // update to the value in the request. - let keys_changed_at = - cmp::max(req.auth_data.keys_changed_at, req.user.keys_changed_at).unwrap_or(0); + let keys_changed_at = match (req.auth_data.keys_changed_at, req.user.keys_changed_at) { + // If the keys_changed_at in the request is larger than that stored on the user record, + // update to the value in the request. + (Some(request_keys_changed_at), Some(user_keys_changed_at)) + if request_keys_changed_at >= user_keys_changed_at => + { + Some(request_keys_changed_at) + } + // If there is a keys_changed_at in the request and it's smaller than that stored on the + // user record, we've already returned an error at this point. + (Some(_request_keys_changed_at), Some(_user_keys_changed_at)) => unreachable!(), + // If there is a keys_changed_at on the request but not one on the user record, this is the + // first time the client reported it, so we assign the new value. + (Some(request_keys_changed_at), None) => Some(request_keys_changed_at), + // At this point, we've already validated that, if there is a keys_changed_at already + // stored on the user record, there must be one in the request. If that isn't the case, + // we've already returned an error. + (None, Some(user_keys_changed_at)) if user_keys_changed_at != 0 => unreachable!(), + // If there's no keys_changed_at in the request and the keys_changed_at on the user record + // is 0, keep the value as 0. + (None, Some(_user_keys_changed_at)) => Some(0), + // If there is no keys_changed_at on the user record or in the request, we want to leave + // the value unset. + (None, None) => None, + }; - let generation = if let Some(generation) = req.auth_data.generation { - // If there's a generation on the request, choose the larger of that and the generation - // already stored on the user record. - cmp::max(generation, req.user.generation) - } else { - // If there's not a generation on the request and the keys_changed_at on the request is - // larger than the generation stored on the user record, set the user's generation to be - // the keys_changed_at on the request. - cmp::max(req.auth_data.keys_changed_at, Some(req.user.generation)).unwrap_or(0) + let generation = match req.auth_data.generation { + // If there's a generation in the request and it's greater than or equal to that stored on + // the user record, update to the value in the request. + Some(request_generation) if request_generation >= req.user.generation => request_generation, + // If there's a generation in the request and it's smaller than that stored on the user + // record, we've already returned an error. + Some(_request_generation) => unreachable!(), + None => match (req.auth_data.keys_changed_at, req.user.keys_changed_at) { + // If there's not a generation on the request but the keys_changed_at on the request + // is greater than the user's current generation AND the keys_changed_at on the request + // is greater than the user's current keys_changed_at, set the user's generation to + // the new keys_changed_at. + (Some(request_keys_changed_at), Some(user_keys_changed_at)) + if request_keys_changed_at > user_keys_changed_at + && request_keys_changed_at > req.user.generation => + { + request_keys_changed_at + } + // If there's not a generation on the request but the keys_changed_at on the request + // is greater than the user's current generation AND there is a keys_changed_at on the + // request but not currently on the user record, set the user's generation to the new + // keys_changed_at. + (Some(request_keys_changed_at), None) + if request_keys_changed_at > req.user.generation => + { + request_keys_changed_at + } + // If the request has a keys_changed_at but the above conditions don't hold OR if the + // request doesn't have a keys_changed_at, just keep the same generation. + (_, _) => req.user.generation, + }, }; // If the client state changed, we need to mark the current user as "replaced" and create a @@ -153,7 +201,7 @@ async fn update_user( }) .await? .id, - keys_changed_at: Some(keys_changed_at), + keys_changed_at, created_at: timestamp, }; let uid = db.post_user(post_user_params).await?.id; @@ -169,15 +217,16 @@ async fn update_user( Ok(UserUpdates { keys_changed_at, + generation, uid, }) } else { - if generation != req.user.generation || Some(keys_changed_at) != req.user.keys_changed_at { + if generation != req.user.generation || keys_changed_at != req.user.keys_changed_at { let params = PutUser { email: req.auth_data.email.clone(), service_id: req.service_id, generation, - keys_changed_at: Some(keys_changed_at), + keys_changed_at, }; db.put_user(params).await?; @@ -185,6 +234,7 @@ async fn update_user( Ok(UserUpdates { keys_changed_at, + generation, uid: req.user.uid, }) } diff --git a/syncstorage/src/tokenserver/mod.rs b/syncstorage/src/tokenserver/mod.rs index 5ec2128e..c9b902ba 100644 --- a/syncstorage/src/tokenserver/mod.rs +++ b/syncstorage/src/tokenserver/mod.rs @@ -35,6 +35,7 @@ pub struct ServerState { pub node_capacity_release_rate: Option, pub node_type: NodeType, pub metrics: Box, + pub token_duration: u64, } impl ServerState { @@ -73,6 +74,7 @@ impl ServerState { node_capacity_release_rate: settings.node_capacity_release_rate, node_type: settings.node_type, metrics: Box::new(metrics), + token_duration: settings.token_duration, } }) .map_err(Into::into) diff --git a/syncstorage/src/tokenserver/settings.rs b/syncstorage/src/tokenserver/settings.rs index fa869005..d17d8321 100644 --- a/syncstorage/src/tokenserver/settings.rs +++ b/syncstorage/src/tokenserver/settings.rs @@ -63,6 +63,8 @@ pub struct Settings { /// verifications do not require requests to FXA if the JWK is set on Tokenserver. The server /// will return an error at startup if the JWK is not cached and this setting is `None`. pub additional_blocking_threads_for_fxa_requests: Option, + /// The amount of time in seconds before a token provided by Tokenserver expires. + pub token_duration: u64, } #[derive(Clone, Debug, Deserialize)] @@ -102,6 +104,7 @@ impl Default for Settings { run_migrations: cfg!(test), spanner_node_id: None, additional_blocking_threads_for_fxa_requests: None, + token_duration: 3600, } } } diff --git a/tools/integration_tests/tokenserver/test_authorization.py b/tools/integration_tests/tokenserver/test_authorization.py index c6d1a86d..1b32a3e9 100644 --- a/tools/integration_tests/tokenserver/test_authorization.py +++ b/tools/integration_tests/tokenserver/test_authorization.py @@ -45,7 +45,7 @@ class TestAuthorization(TestCase, unittest.TestCase): self.assertEqual(res.json, expected_error_response) def test_invalid_client_state_in_key_id(self): - if self.AUTH_METHOD == "oauth": + if self.auth_method == "oauth": additional_headers = { 'X-KeyID': "1234-state!" } @@ -277,6 +277,54 @@ class TestAuthorization(TestCase, unittest.TestCase): # This should not result in the creation of a new user self.assertEqual(res.json['uid'], uid) + def test_set_generation_unchanged_without_keys_changed_at_update(self): + # Add a user who has never sent us a generation + uid = self._add_user(generation=0, keys_changed_at=1234, + client_state='aaaa') + # Send a request without a generation that doesn't update + # keys_changed_at + headers = self._build_auth_headers(generation=None, + keys_changed_at=1234, + client_state='aaaa') + self.app.get('/1.0/sync/1.5', headers=headers) + user = self._get_user(uid) + # This should not have set the user's generation + self.assertEqual(user['generation'], 0) + # Send a request without a generation that updates keys_changed_at + headers = self._build_auth_headers(generation=None, + keys_changed_at=1235, + client_state='aaaa') + self.app.get('/1.0/sync/1.5', headers=headers) + user = self._get_user(uid) + # This should have set the user's generation + self.assertEqual(user['generation'], 1235) + + def test_set_generation_with_keys_changed_at_initialization(self): + # Add a user who has never sent us a generation or a keys_changed_at + uid = self._add_user(generation=0, keys_changed_at=None, + client_state='aaaa') + + # Only BrowserID requests can omit keys_changed_at + if self.auth_method == 'browserid': + # Send a request without a generation that doesn't update + # keys_changed_at + headers = self._build_auth_headers(generation=None, + keys_changed_at=None, + client_state='aaaa') + self.app.get('/1.0/sync/1.5', headers=headers) + user = self._get_user(uid) + # This should not have set the user's generation + self.assertEqual(user['generation'], 0) + + # Send a request without a generation that updates keys_changed_at + headers = self._build_auth_headers(generation=None, + keys_changed_at=1234, + client_state='aaaa') + self.app.get('/1.0/sync/1.5', headers=headers) + user = self._get_user(uid) + # This should have set the user's generation + self.assertEqual(user['generation'], 1234) + def test_fxa_kid_change(self): self._add_user(generation=1234, keys_changed_at=None, client_state='aaaa') @@ -337,12 +385,12 @@ class TestAuthorization(TestCase, unittest.TestCase): self.assertEquals(res.json['duration'], 12) # But you can't exceed the server's default value. res = self.app.get('/1.0/sync/1.5?duration=4000', headers=headers) - self.assertEquals(res.json['duration'], 300) + self.assertEquals(res.json['duration'], 3600) # And nonsense values are ignored. res = self.app.get('/1.0/sync/1.5?duration=lolwut', headers=headers) - self.assertEquals(res.json['duration'], 300) + self.assertEquals(res.json['duration'], 3600) res = self.app.get('/1.0/sync/1.5?duration=-1', headers=headers) - self.assertEquals(res.json['duration'], 300) + self.assertEquals(res.json['duration'], 3600) # Although all servers are now writing keys_changed_at, we still need this # case to be handled. See this PR for more information: @@ -486,7 +534,7 @@ class TestAuthorization(TestCase, unittest.TestCase): self.assertEqual(user['keys_changed_at'], 1234) def test_x_client_state_must_have_same_client_state_as_key_id(self): - if self.AUTH_METHOD == "oauth": + if self.auth_method == "oauth": self._add_user(client_state='aaaa') additional_headers = {'X-Client-State': 'bbbb'} headers = self._build_auth_headers(generation=1234, @@ -509,3 +557,31 @@ class TestAuthorization(TestCase, unittest.TestCase): self.assertEqual(res.json, expected_error_response) headers['X-Client-State'] = 'aaaa' res = self.app.get('/1.0/sync/1.5', headers=headers) + + def test_zero_generation_treated_as_null(self): + # Add a user that has a generation set + uid = self._add_user(generation=1234, keys_changed_at=1234, + client_state='aaaa') + headers = self._build_auth_headers(generation=0, + keys_changed_at=1234, + client_state='aaaa') + # Send a request with a generation of 0 + self.app.get('/1.0/sync/1.5', headers=headers) + # Ensure that the request succeeded and that the user's generation + # was not updated + user = self._get_user(uid) + self.assertEqual(user['generation'], 1234) + + def test_zero_keys_changed_at_treated_as_null(self): + # Add a user that has no keys_changed_at set + uid = self._add_user(generation=1234, keys_changed_at=None, + client_state='aaaa') + headers = self._build_auth_headers(generation=1234, + keys_changed_at=0, + client_state='aaaa') + # Send a request with a keys_changed_at of 0 + self.app.get('/1.0/sync/1.5', headers=headers) + # Ensure that the request succeeded and that the user's + # keys_changed_at was not updated + user = self._get_user(uid) + self.assertEqual(user['keys_changed_at'], None) diff --git a/tools/integration_tests/tokenserver/test_browserid.py b/tools/integration_tests/tokenserver/test_browserid.py index 0e3b0941..80921840 100644 --- a/tools/integration_tests/tokenserver/test_browserid.py +++ b/tools/integration_tests/tokenserver/test_browserid.py @@ -474,3 +474,89 @@ class TestBrowserId(TestCase, unittest.TestCase): client_state="aaaa") res = self.app.get("/1.0/sync/1.5", headers=headers, status=401) self.assertEqual(res.json["status"], "invalid-generation") + + def test_reverting_to_no_keys_changed_at(self): + # Add a user that has no keys_changed_at set + uid = self._add_user(generation=0, keys_changed_at=None, + client_state='aaaa') + # Send a request with keys_changed_at + headers = self._build_browserid_headers(generation=None, + keys_changed_at=1234, + client_state='aaaa') + self.app.get('/1.0/sync/1.5', headers=headers) + user = self._get_user(uid) + # Confirm that keys_changed_at was set + self.assertEqual(user['keys_changed_at'], 1234) + # Send a request with no keys_changed_at + headers = self._build_browserid_headers(generation=None, + keys_changed_at=None, + client_state='aaaa') + # Once a keys_changed_at has been set, the server expects to receive + # it from that point onwards + res = self.app.get('/1.0/sync/1.5', headers=headers, status=401) + expected_error_response = { + 'status': 'invalid-keysChangedAt', + 'errors': [ + { + 'location': 'body', + 'name': '', + 'description': 'Unauthorized', + } + ] + } + self.assertEqual(res.json, expected_error_response) + + def test_zero_keys_changed_at_treated_as_null(self): + # Add a user that has a zero keys_changed_at + uid = self._add_user(generation=0, keys_changed_at=0, + client_state='aaaa') + # Send a request with no keys_changed_at + headers = self._build_browserid_headers(generation=None, + keys_changed_at=None, + client_state='aaaa') + self.app.get('/1.0/sync/1.5', headers=headers) + # The request should succeed and the keys_changed_at should be + # unchanged + user = self._get_user(uid) + self.assertEqual(user['keys_changed_at'], 0) + + def test_reverting_to_no_client_state(self): + # Add a user that has no client_state + uid = self._add_user(generation=0, keys_changed_at=None, + client_state="") + # Send a request with no client state + headers = self._build_browserid_headers(generation=None, + keys_changed_at=None, + client_state=None) + # The request should succeed + self.app.get('/1.0/sync/1.5', headers=headers) + # Send a request that updates the client state + headers = self._build_browserid_headers(generation=None, + keys_changed_at=None, + client_state='aaaa') + # The request should succeed + res = self.app.get('/1.0/sync/1.5', headers=headers) + user = self._get_user(res.json['uid']) + # A new user should have been created + self.assertNotEqual(uid, res.json['uid']) + # The client state should have been updated + self.assertEqual(user['client_state'], 'aaaa') + # Send another request with no client state + headers = self._build_browserid_headers(generation=None, + keys_changed_at=None, + client_state=None) + # The request should fail, since we are trying to revert to using no + # client state after setting one + res = self.app.get('/1.0/sync/1.5', headers=headers, status=401) + expected_error_response = { + 'status': 'invalid-client-state', + 'errors': [ + { + 'location': 'header', + 'name': 'X-Client-State', + 'description': 'Unacceptable client-state value empty ' + 'string', + } + ] + } + self.assertEqual(res.json, expected_error_response) diff --git a/tools/integration_tests/tokenserver/test_e2e.py b/tools/integration_tests/tokenserver/test_e2e.py index 91753743..90899e20 100644 --- a/tools/integration_tests/tokenserver/test_e2e.py +++ b/tools/integration_tests/tokenserver/test_e2e.py @@ -25,7 +25,7 @@ from tokenserver.test_support import TestCase # this is the proper client ID to be using for these integration tests. BROWSERID_AUDIENCE = "https://token.stage.mozaws.net" CLIENT_ID = '5882386c6d801776' -DEFAULT_TOKEN_DURATION = 300 +DEFAULT_TOKEN_DURATION = 3600 FXA_ACCOUNT_STAGE_HOST = 'https://api-accounts.stage.mozaws.net' FXA_OAUTH_STAGE_HOST = 'https://oauth.stage.mozaws.net' PASSWORD_CHARACTERS = string.ascii_letters + string.punctuation + string.digits diff --git a/tools/integration_tests/tokenserver/test_misc.py b/tools/integration_tests/tokenserver/test_misc.py index 38c1bdb7..c82d629d 100644 --- a/tools/integration_tests/tokenserver/test_misc.py +++ b/tools/integration_tests/tokenserver/test_misc.py @@ -57,7 +57,7 @@ class TestMisc(TestCase, unittest.TestCase): res = self.app.get('/1.0/sync/1.5', headers=headers) self.assertIn('https://example.com/1.5', res.json['api_endpoint']) self.assertIn('duration', res.json) - self.assertEquals(res.json['duration'], 300) + self.assertEquals(res.json['duration'], 3600) def test_current_user_is_the_most_up_to_date(self): # Add some users diff --git a/tools/integration_tests/tokenserver/test_support.py b/tools/integration_tests/tokenserver/test_support.py index 6f8086f3..4e849d21 100644 --- a/tools/integration_tests/tokenserver/test_support.py +++ b/tools/integration_tests/tokenserver/test_support.py @@ -17,7 +17,6 @@ DEFAULT_OAUTH_SCOPE = 'https://identity.mozilla.com/apps/oldsync' class TestCase: - AUTH_METHOD = os.environ.get('TOKENSERVER_AUTH_METHOD', 'oauth') BROWSERID_ISSUER = os.environ['SYNC_TOKENSERVER__FXA_BROWSERID_ISSUER'] FXA_EMAIL_DOMAIN = 'api-accounts.stage.mozaws.net' FXA_METRICS_HASH_SECRET = 'secret0' @@ -26,6 +25,15 @@ class TestCase: TOKEN_SIGNING_SECRET = 'secret0' TOKENSERVER_HOST = os.environ['TOKENSERVER_HOST'] + @classmethod + def setUpClass(cls): + cls.auth_method = os.environ['TOKENSERVER_AUTH_METHOD'] + + if cls.auth_method == 'browserid': + cls._build_auth_headers = cls._build_browserid_headers + else: + cls._build_auth_headers = cls._build_oauth_headers + def setUp(self): engine = create_engine(os.environ['SYNC_TOKENSERVER__DATABASE_URL']) self.database = engine. \ @@ -41,11 +49,6 @@ class TestCase: 'SCRIPT_NAME': host_url.path, }) - if self.AUTH_METHOD == 'browserid': - self._build_auth_headers = self._build_browserid_headers - else: - self._build_auth_headers = self._build_oauth_headers - # Start each test with a blank slate. cursor = self._execute_sql(('DELETE FROM users'), ()) cursor.close() @@ -80,6 +83,10 @@ class TestCase: 'client_id': 'fake client id', 'scope': [DEFAULT_OAUTH_SCOPE], } + + if generation is not None: + claims['generation'] = generation + body = { 'body': claims, 'status': status @@ -105,8 +112,8 @@ class TestCase: 'issuer': issuer } - if device_id or generation or keys_changed_at or \ - token_verified is not None: + if device_id or generation is not None or \ + keys_changed_at is not None or token_verified is not None: idp_claims = {} if device_id: @@ -130,9 +137,11 @@ class TestCase: headers = { 'Authorization': 'BrowserID %s' % json.dumps(body), - 'X-Client-State': client_state } + if client_state: + headers['X-Client-State'] = client_state + headers.update(additional_headers) return headers