diff --git a/src/tokenserver/extractors.rs b/src/tokenserver/extractors.rs index 01f59d92..70d7cb11 100644 --- a/src/tokenserver/extractors.rs +++ b/src/tokenserver/extractors.rs @@ -7,12 +7,15 @@ use std::sync::Arc; use actix_web::{ dev::Payload, + http::StatusCode, web::{Data, Query}, Error, FromRequest, HttpRequest, }; use actix_web_httpauth::extractors::bearer::BearerAuth; use futures::future::LocalBoxFuture; use hmac::{Hmac, Mac, NewMac}; +use lazy_static::lazy_static; +use regex::Regex; use serde::Deserialize; use sha2::Sha256; @@ -22,6 +25,10 @@ use super::support::TokenData; use super::ServerState; use crate::settings::Secrets; +lazy_static! { + static ref CLIENT_STATE_REGEX: Regex = Regex::new("^[a-zA-Z0-9._-]{1,32}$").unwrap(); +} + const DEFAULT_TOKEN_DURATION: u64 = 5 * 60; /// Information from the request needed to process a Tokenserver request. @@ -279,9 +286,9 @@ impl FromRequest for TokenData { let authorization_header = req .headers() .get("Authorization") - .ok_or_else(|| TokenserverError::invalid_credentials("Unauthorized"))? + .ok_or_else(|| TokenserverError::unauthorized("Unauthorized"))? .to_str() - .map_err(|_| TokenserverError::invalid_credentials("Unauthorized"))?; + .map_err(|_| TokenserverError::unauthorized("Unauthorized"))?; // The request must use Bearer auth if let Some((auth_type, _)) = authorization_header.split_once(" ") { @@ -322,6 +329,24 @@ impl FromRequest for KeyId { Box::pin(async move { let headers = req.headers(); + let maybe_x_client_state = headers + .get("X-Client-State") + .and_then(|header| header.to_str().ok()); + + // If there's a client state value in the X-Client-State header, make sure it is valid + if let Some(x_client_state) = maybe_x_client_state { + if !CLIENT_STATE_REGEX.is_match(x_client_state) { + return Err(TokenserverError { + status: "error", + location: ErrorLocation::Header, + description: "Invalid client state value", + name: "X-Client-State".to_owned(), + http_status: StatusCode::BAD_REQUEST, + } + .into()); + } + } + let x_key_id = headers .get("X-KeyId") .ok_or_else(|| TokenserverError::invalid_key_id("Missing X-KeyID header"))? @@ -341,9 +366,6 @@ impl FromRequest for KeyId { // If there's a client state value in the X-Client-State header, verify that it matches // the value in X-KeyID. - let maybe_x_client_state = headers - .get("X-Client-State") - .and_then(|header| header.to_str().ok()); if let Some(x_client_state) = maybe_x_client_state { if x_client_state != client_state { return Err(TokenserverError { diff --git a/tools/integration_tests/tokenserver/test_authorization.py b/tools/integration_tests/tokenserver/test_authorization.py index fa3dcbbf..5b10caf0 100644 --- a/tools/integration_tests/tokenserver/test_authorization.py +++ b/tools/integration_tests/tokenserver/test_authorization.py @@ -30,11 +30,7 @@ class TestAuthorization(TestCase, unittest.TestCase): self.assertEqual(res.json, expected_error_response) def test_no_auth(self): - self.app.get('/1.0/sync/1.5', status=401) - - def test_invalid_client_state(self): - headers = {'X-KeyID': '1234-state!'} - resp = self.app.get('/1.0/sync/1.5', headers=headers, status=401) + res = self.app.get('/1.0/sync/1.5', status=401) expected_error_response = { 'status': 'error', @@ -46,7 +42,46 @@ class TestAuthorization(TestCase, unittest.TestCase): } ] } - self.assertEqual(resp.json, expected_error_response) + self.assertEqual(res.json, expected_error_response) + + def test_invalid_client_state_in_key_id(self): + headers = { + 'Authorization': 'Bearer %s' % self._forge_oauth_token(), + 'X-KeyID': '1234-state!' + } + res = self.app.get('/1.0/sync/1.5', headers=headers, status=401) + + expected_error_response = { + 'status': 'invalid-credentials', + 'errors': [ + { + 'location': 'body', + 'name': '', + 'description': 'Unauthorized' + } + ] + } + self.assertEqual(res.json, expected_error_response) + + def test_invalid_client_state_in_x_client_state(self): + headers = { + 'Authorization': 'Bearer %s' % self._forge_oauth_token(), + 'X-KeyID': '1234-YWFh', + 'X-Client-State': 'state!' + } + res = self.app.get('/1.0/sync/1.5', headers=headers, status=400) + + expected_error_response = { + 'status': 'error', + 'errors': [ + { + 'location': 'header', + 'name': 'X-Client-State', + 'description': 'Invalid client state value' + } + ] + } + self.assertEqual(res.json, expected_error_response) def test_keys_changed_at_less_than_equal_to_generation(self): self._add_user(generation=1232, keys_changed_at=1234)