feat: Puts pyo3 behind feature flag and derives tokens directly in Rust (#1513)

* Removes pyo3 and derives tokens directly in Rust

* Adds tests for JWT verifying

* Adds tests for token generation

* Adds metrics for oauth verify error cases

* Updates jsonwebtoken to not include default features (including pem loading)

* Adds context and logs errors during oauth verify

* Uses ring for cryptographic rng

* Adds back python impl under feature flag

* Uses one cached http client for reqwest
This commit is contained in:
Tarik Eshaq 2024-02-12 11:14:15 -05:00 committed by GitHub
parent d544a0e378
commit 1b11684648
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 1321 additions and 295 deletions

View File

@ -48,13 +48,13 @@ commands:
- run:
name: Rust Clippy MySQL
command: |
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/mysql -- -D warnings
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/mysql --features=py_verifier -- -D warnings
rust-clippy-spanner:
steps:
- run:
name: Rust Clippy Spanner
command: |
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/spanner -- -D warnings
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/spanner --features=py_verifier -- -D warnings
cargo-build:
steps:
- run:

26
Cargo.lock generated
View File

@ -1068,8 +1068,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi",
"wasm-bindgen",
]
[[package]]
@ -1439,6 +1441,19 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "jsonwebtoken"
version = "9.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c7ea04a7c5c055c175f189b6dc6ba036fd62306b58c66c9f6389036c503a3f4"
dependencies = [
"base64",
"js-sys",
"ring",
"serde 1.0.195",
"serde_json",
]
[[package]]
name = "language-tags"
version = "0.3.2"
@ -3032,14 +3047,23 @@ name = "tokenserver-auth"
version = "0.14.4"
dependencies = [
"async-trait",
"base64",
"dyn-clone",
"futures 0.3.30",
"hex",
"hkdf",
"hmac",
"jsonwebtoken",
"mockito",
"pyo3",
"reqwest",
"ring",
"serde 1.0.195",
"serde_json",
"sha2",
"slog-scope",
"syncserver-common",
"thiserror",
"tokenserver-common",
"tokenserver-settings",
"tokio",
@ -3051,6 +3075,7 @@ version = "0.14.4"
dependencies = [
"actix-web",
"backtrace",
"jsonwebtoken",
"serde 1.0.195",
"serde_json",
"syncserver-common",
@ -3086,6 +3111,7 @@ dependencies = [
name = "tokenserver-settings"
version = "0.14.4"
dependencies = [
"jsonwebtoken",
"serde 1.0.195",
"tokenserver-common",
]

View File

@ -38,7 +38,10 @@ docopt = "1.1"
env_logger = "0.10"
futures = { version = "0.3", features = ["compat"] }
hex = "0.4"
hkdf = "0.12"
hmac = "0.12"
http = "0.2"
jsonwebtoken = { version = "9.2", default-features = false }
lazy_static = "1.4"
protobuf = "=2.25.2" # pin to 2.25.2 to prevent side updating
rand = "0.8"
@ -62,6 +65,7 @@ slog-scope = "4.3"
slog-stdlog = "4.1"
slog-term = "2.6"
tokio = "1"
thiserror = "1.0.26"
[profile.release]
# Enables line numbers in Sentry reporting

View File

@ -19,7 +19,7 @@ RUN \
apt-get -q install -y --no-install-recommends libmysqlclient-dev cmake
COPY --from=planner /app/recipe.json recipe.json
RUN cargo chef cook --release --no-default-features --features=syncstorage-db/$DATABASE_BACKEND --recipe-path recipe.json
RUN cargo chef cook --release --no-default-features --features=syncstorage-db/$DATABASE_BACKEND --features=py_verifier --recipe-path recipe.json
FROM chef as builder
ARG DATABASE_BACKEND=spanner
@ -46,7 +46,7 @@ ENV PATH=$PATH:/root/.cargo/bin
RUN \
cargo --version && \
rustc --version && \
cargo install --path ./syncserver --no-default-features --features=syncstorage-db/$DATABASE_BACKEND --locked --root /app && \
cargo install --path ./syncserver --no-default-features --features=syncstorage-db/$DATABASE_BACKEND --features=py_verifier --locked --root /app && \
if [ "$DATABASE_BACKEND" = "spanner" ] ; then cargo install --path ./syncstorage-spanner --locked --root /app --bin purge_ttl ; fi
FROM docker.io/library/debian:bullseye-slim
@ -56,11 +56,12 @@ COPY --from=builder /app/requirements.txt /app
# have to set this env var to prevent the cryptography package from building
# with Rust. See this link for more information:
# https://pythonshowcase.com/question/problem-installing-cryptography-on-raspberry-pi
ENV CRYPTOGRAPHY_DONT_BUILD_RUST=1
RUN \
apt-get -q update && apt-get -qy install wget
ENV CRYPTOGRAPHY_DONT_BUILD_RUST=1
RUN \
groupadd --gid 10001 app && \
useradd --uid 10001 --gid 10001 --home /app --create-home app && \

View File

@ -15,11 +15,11 @@ PYTHON_SITE_PACKGES = $(shell $(SRC_ROOT)/venv/bin/python -c "from distutils.sys
clippy_mysql:
# Matches what's run in circleci
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/mysql -- -D warnings
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/mysql --features=py_verifier -- -D warnings
clippy_spanner:
# Matches what's run in circleci
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/spanner -- -D warnings
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/spanner --features=py_verifier -- -D warnings
clean:
cargo clean
@ -47,14 +47,15 @@ python:
python3 -m venv venv
venv/bin/python -m pip install -r requirements.txt
run_mysql: python
PATH="./venv/bin:$(PATH)" \
# See https://github.com/PyO3/pyo3/issues/1741 for discussion re: why we need to set the
# below env var
PYTHONPATH=$(PYTHON_SITE_PACKGES) \
RUST_LOG=debug \
RUST_LOG=debug \
RUST_BACKTRACE=full \
cargo run --no-default-features --features=syncstorage-db/mysql -- --config config/local.toml
cargo run --no-default-features --features=syncstorage-db/mysql --features=py_verifier -- --config config/local.toml
run_spanner: python
GOOGLE_APPLICATION_CREDENTIALS=$(PATH_TO_SYNC_SPANNER_KEYS) \
@ -65,7 +66,7 @@ run_spanner: python
PATH="./venv/bin:$(PATH)" \
RUST_LOG=debug \
RUST_BACKTRACE=full \
cargo run --no-default-features --features=syncstorage-db/spanner -- --config config/local.toml
cargo run --no-default-features --features=syncstorage-db/spanner --features=py_verifier -- --config config/local.toml
test:
SYNC_SYNCSTORAGE__DATABASE_URL=mysql://sample_user:sample_password@localhost/syncstorage_rs \

View File

@ -14,5 +14,5 @@ serde_json.workspace = true
slog.workspace = true
slog-scope.workspace = true
actix-web.workspace = true
hkdf.workspace = true
hkdf = "0.12"

View File

@ -9,9 +9,9 @@ edition.workspace=true
backtrace.workspace=true
futures.workspace=true
http.workspace=true
thiserror.workspace=true
deadpool = { git = "https://github.com/mozilla-services/deadpool", tag = "deadpool-v0.7.0" }
diesel = { version = "1.4", features = ["mysql", "r2d2"] }
diesel_migrations = { version = "1.4.0", features = ["mysql"] }
syncserver-common = { path = "../syncserver-common" }
thiserror = "1.0.26"

View File

@ -32,6 +32,8 @@ slog-mozlog-json.workspace = true
slog-scope.workspace = true
slog-stdlog.workspace = true
slog-term.workspace = true
hmac.workspace = true
thiserror.workspace = true
actix-http = "3"
actix-rt = "2"
@ -40,7 +42,6 @@ async-trait = "0.1.40"
dyn-clone = "1.0.4"
hostname = "0.3.1"
hawk = "5.0"
hmac = "0.12"
mime = "0.3"
reqwest = { workspace = true, features = [
"json",
@ -53,8 +54,7 @@ syncserver-settings = { path = "../syncserver-settings" }
syncstorage-db = { path = "../syncstorage-db" }
syncstorage-settings = { path = "../syncstorage-settings" }
time = "^0.3"
thiserror = "1.0.26"
tokenserver-auth = { path = "../tokenserver-auth" }
tokenserver-auth = { path = "../tokenserver-auth", default-features = false}
tokenserver-common = { path = "../tokenserver-common" }
tokenserver-db = { path = "../tokenserver-db" }
tokenserver-settings = { path = "../tokenserver-settings" }
@ -65,7 +65,8 @@ validator_derive = "0.16"
woothee = "0.13"
[features]
default = ["mysql"]
default = ["mysql", "py_verifier"]
no_auth = []
py_verifier = ["tokenserver-auth/py"]
mysql = ["syncstorage-db/mysql"]
spanner = ["syncstorage-db/spanner"]

View File

@ -457,7 +457,8 @@ impl FromRequest for AuthData {
let mut tags = HashMap::default();
tags.insert("token_type".to_owned(), "BrowserID".to_owned());
metrics.start_timer("token_verification", Some(tags));
let verify_output = state.browserid_verifier.verify(assertion).await?;
let verify_output =
state.browserid_verifier.verify(assertion, &metrics).await?;
// For requests using BrowserID, the client state is embedded in the
// X-Client-State header, and the generation and keys_changed_at are extracted
@ -487,7 +488,7 @@ impl FromRequest for AuthData {
let mut tags = HashMap::default();
tags.insert("token_type".to_owned(), "OAuth".to_owned());
metrics.start_timer("token_verification", Some(tags));
let verify_output = state.oauth_verifier.verify(token).await?;
let verify_output = state.oauth_verifier.verify(token, &metrics).await?;
// For requests using OAuth, the keys_changed_at and client state are embedded
// in the X-KeyID header.

View File

@ -9,6 +9,8 @@ use serde::{
Serialize,
};
use syncserver_common::{BlockingThreadpool, Metrics};
#[cfg(not(feature = "py_verifier"))]
use tokenserver_auth::JWTVerifierImpl;
use tokenserver_auth::{browserid, oauth, VerifyToken};
use tokenserver_common::NodeType;
use tokenserver_db::{params, DbPool, TokenserverPool};
@ -40,6 +42,32 @@ impl ServerState {
metrics: Arc<StatsdClient>,
blocking_threadpool: Arc<BlockingThreadpool>,
) -> Result<Self, ApiError> {
#[cfg(not(feature = "py_verifier"))]
let oauth_verifier = {
let mut jwk_verifiers: Vec<JWTVerifierImpl> = Vec::new();
if let Some(primary) = &settings.fxa_oauth_primary_jwk {
jwk_verifiers.push(
primary
.clone()
.try_into()
.expect("Invalid primary key, should either be fixed or removed"),
)
}
if let Some(secondary) = &settings.fxa_oauth_secondary_jwk {
jwk_verifiers.push(
secondary
.clone()
.try_into()
.expect("Invalid secondary key, should either be fixed or removed"),
);
}
Box::new(
oauth::Verifier::new(settings, jwk_verifiers)
.expect("failed to create Tokenserver OAuth verifier"),
)
};
#[cfg(feature = "py_verifier")]
let oauth_verifier = Box::new(
oauth::Verifier::new(settings, blocking_threadpool.clone())
.expect("failed to create Tokenserver OAuth verifier"),

View File

@ -13,10 +13,10 @@ lazy_static.workspace=true
http.workspace=true
serde.workspace=true
serde_json.workspace=true
thiserror.workspace=true
async-trait = "0.1.40"
diesel = { version = "1.4", features = ["mysql", "r2d2"] }
diesel_migrations = { version = "1.4.0", features = ["mysql"] }
syncserver-common = { path = "../syncserver-common" }
syncserver-db-common = { path = "../syncserver-db-common" }
thiserror = "1.0.26"

View File

@ -11,6 +11,7 @@ base64.workspace=true
futures.workspace=true
http.workspace=true
slog-scope.workspace=true
thiserror.workspace=true
async-trait = "0.1.40"
diesel = { version = "1.4", features = ["mysql", "r2d2"] }
@ -20,7 +21,6 @@ syncserver-common = { path = "../syncserver-common" }
syncserver-db-common = { path = "../syncserver-db-common" }
syncstorage-db-common = { path = "../syncstorage-db-common" }
syncstorage-settings = { path = "../syncstorage-settings" }
thiserror = "1.0.26"
url = "2.1"
[dev-dependencies]

View File

@ -12,6 +12,7 @@ env_logger.workspace = true
futures.workspace = true
http.workspace = true
slog-scope.workspace = true
thiserror.workspace = true
async-trait = "0.1.40"
google-cloud-rust-raw = { version = "0.16.1", features = ["spanner"] }
@ -30,7 +31,6 @@ syncserver-common = { path = "../syncserver-common" }
syncserver-db-common = { path = "../syncserver-db-common" }
syncstorage-db-common = { path = "../syncstorage-db-common" }
syncstorage-settings = { path = "../syncstorage-settings" }
thiserror = "1.0.26"
tokio = { workspace = true, features = [
"macros",
"sync",

View File

@ -11,15 +11,30 @@ edition.workspace = true
futures.workspace = true
serde.workspace = true
serde_json.workspace = true
hex.workspace = true
hkdf.workspace = true
hmac.workspace = true
jsonwebtoken.workspace = true
base64.workspace = true
sha2.workspace = true
thiserror.workspace = true
slog-scope.workspace = true
async-trait = "0.1.40"
dyn-clone = "1.0.4"
pyo3 = { version = "0.20", features = ["auto-initialize"] }
reqwest = { workspace = true, features = ["json", "rustls-tls"] }
ring = "0.17"
syncserver-common = { path = "../syncserver-common" }
tokenserver-common = { path = "../tokenserver-common" }
tokenserver-settings = { path = "../tokenserver-settings" }
tokio = { workspace = true }
pyo3 = { version = "0.20", features = ["auto-initialize"], optional = true}
[dev-dependencies]
mockito = "0.30.0"
tokio = { workspace = true, features = ["macros"]}
[features]
default = ["py"]
py = ["pyo3"]

View File

@ -1,6 +1,7 @@
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, StatusCode};
use serde::{de::Deserializer, Deserialize, Serialize};
use syncserver_common::Metrics;
use tokenserver_common::{ErrorLocation, TokenType, TokenserverError};
use tokenserver_settings::Settings;
@ -52,7 +53,11 @@ impl VerifyToken for Verifier {
/// Verifies a BrowserID assertion. Returns `VerifyOutput` for valid assertions and a
/// `TokenserverError` for invalid assertions.
async fn verify(&self, assertion: String) -> Result<VerifyOutput, TokenserverError> {
async fn verify(
&self,
assertion: String,
_metrics: &Metrics,
) -> Result<VerifyOutput, TokenserverError> {
let response = self
.request_client
.post(&self.fxa_verifier_url)
@ -313,7 +318,10 @@ mod tests {
})
.unwrap();
let result = verifier.verify("test".to_owned()).await.unwrap();
let result = verifier
.verify("test".to_owned(), &Default::default())
.await
.unwrap();
mock.assert();
let expected_result = VerifyOutput {
@ -345,7 +353,10 @@ mod tests {
.with_header("content-type", "application/json")
.create();
let error = verifier.verify(assertion.to_owned()).await.unwrap_err();
let error = verifier
.verify(assertion.to_owned(), &Default::default())
.await
.unwrap_err();
mock.assert();
let expected_error = TokenserverError {
@ -363,7 +374,10 @@ mod tests {
.with_body("<h1>Server Error</h1>")
.create();
let error = verifier.verify(assertion.to_owned()).await.unwrap_err();
let error = verifier
.verify(assertion.to_owned(), &Default::default())
.await
.unwrap_err();
mock.assert();
let expected_error = TokenserverError {
@ -381,7 +395,10 @@ mod tests {
.with_body("{\"status\": \"error\"}")
.create();
let error = verifier.verify(assertion.to_owned()).await.unwrap_err();
let error = verifier
.verify(assertion.to_owned(), &Default::default())
.await
.unwrap_err();
mock.assert();
let expected_error = TokenserverError {
@ -399,7 +416,10 @@ mod tests {
.with_body("{\"status\": \"potato\"}")
.create();
let error = verifier.verify(assertion.to_owned()).await.unwrap_err();
let error = verifier
.verify(assertion.to_owned(), &Default::default())
.await
.unwrap_err();
mock.assert();
let expected_error = TokenserverError {
@ -417,7 +437,10 @@ mod tests {
.with_body("{\"status\": \"failure\", \"reason\": \"something broke\"}")
.create();
let error = verifier.verify(assertion.to_owned()).await.unwrap_err();
let error = verifier
.verify(assertion.to_owned(), &Default::default())
.await
.unwrap_err();
mock.assert();
let expected_error = TokenserverError {
@ -434,7 +457,10 @@ mod tests {
.with_body("{\"status\": \"failure\"}")
.create();
let error = verifier.verify(assertion.to_owned()).await.unwrap_err();
let error = verifier
.verify(assertion.to_owned(), &Default::default())
.await
.unwrap_err();
mock.assert();
let expected_error = TokenserverError {
@ -481,7 +507,10 @@ mod tests {
{
let mock = mock("login.persona.org");
let error = verifier.verify(assertion.clone()).await.unwrap_err();
let error = verifier
.verify(assertion.clone(), &Default::default())
.await
.unwrap_err();
mock.assert();
assert_eq!(expected_error, error);
@ -489,7 +518,10 @@ mod tests {
{
let mock = mock(ISSUER);
let result = verifier.verify(assertion.clone()).await.unwrap();
let result = verifier
.verify(assertion.clone(), &Default::default())
.await
.unwrap();
let expected_result = VerifyOutput {
device_id: None,
email: "test@example.com".to_owned(),
@ -503,7 +535,10 @@ mod tests {
{
let mock = mock("accounts.firefox.org");
let error = verifier.verify(assertion.clone()).await.unwrap_err();
let error = verifier
.verify(assertion.clone(), &Default::default())
.await
.unwrap_err();
mock.assert();
assert_eq!(expected_error, error);
@ -511,7 +546,10 @@ mod tests {
{
let mock = mock("http://accounts.firefox.com");
let error = verifier.verify(assertion.clone()).await.unwrap_err();
let error = verifier
.verify(assertion.clone(), &Default::default())
.await
.unwrap_err();
mock.assert();
assert_eq!(expected_error, error);
@ -519,7 +557,10 @@ mod tests {
{
let mock = mock("accounts.firefox.co");
let error = verifier.verify(assertion.clone()).await.unwrap_err();
let error = verifier
.verify(assertion.clone(), &Default::default())
.await
.unwrap_err();
mock.assert();
assert_eq!(expected_error, error);
@ -536,7 +577,10 @@ mod tests {
.with_header("content-type", "application/json")
.with_body(body.to_string())
.create();
let error = verifier.verify(assertion.clone()).await.unwrap_err();
let error = verifier
.verify(assertion.clone(), &Default::default())
.await
.unwrap_err();
mock.assert();
let expected_error = TokenserverError {
@ -558,7 +602,10 @@ mod tests {
.with_header("content-type", "application/json")
.with_body(body.to_string())
.create();
let error = verifier.verify(assertion.clone()).await.unwrap_err();
let error = verifier
.verify(assertion.clone(), &Default::default())
.await
.unwrap_err();
mock.assert();
let expected_error = TokenserverError {
@ -579,7 +626,10 @@ mod tests {
.with_header("content-type", "application/json")
.with_body(body.to_string())
.create();
let error = verifier.verify(assertion).await.unwrap_err();
let error = verifier
.verify(assertion, &Default::default())
.await
.unwrap_err();
mock.assert();
let expected_error = TokenserverError {

View File

@ -0,0 +1,172 @@
use hkdf::Hkdf;
use hmac::{Hmac, Mac};
use jsonwebtoken::{errors::ErrorKind, jwk::Jwk, Algorithm, DecodingKey, Validation};
use ring::rand::{SecureRandom, SystemRandom};
use serde::de::DeserializeOwned;
use sha2::Sha256;
use tokenserver_common::TokenserverError;
/// Number of bytes in a SHA-256 digest; also the length of every key
/// derived by `Crypto::hkdf` below.
pub const SHA256_OUTPUT_LEN: usize = 32;

/// A trait representing all the required cryptographic operations by the token server.
pub trait Crypto {
    /// Error type returned by every operation.
    type Error;

    /// HKDF key derivation.
    ///
    /// This expands `info` into a 32 byte value using `secret` and the optional `salt`.
    /// Salt is normally specified, except when this function is called in
    /// [syncserver-settings::Secrets::new] or when deriving a key to be used to sign
    /// the tokenserver tokens, so both syncserver and tokenserver can sign and
    /// validate the signatures.
    fn hkdf(&self, secret: &str, salt: Option<&[u8]>, info: &[u8]) -> Result<Vec<u8>, Self::Error>;

    /// HMAC signature.
    ///
    /// Signs the `payload` using HMAC given the `key`.
    fn hmac_sign(&self, key: &[u8], payload: &[u8]) -> Result<Vec<u8>, Self::Error>;

    /// Verify an HMAC signature on a payload given a shared key.
    fn hmac_verify(&self, key: &[u8], payload: &[u8], signature: &[u8]) -> Result<(), Self::Error>;

    /// Generates random bytes using a cryptographic random number generator
    /// and fills `output` with those bytes.
    fn rand_bytes(&self, output: &mut [u8]) -> Result<(), Self::Error>;
}
/// Implementation of the cryptographic operations required by the token
/// server, backed by the `hkdf` crate for HKDF, the `hmac` crate for HMAC,
/// and `ring` for random number generation.
pub struct CryptoImpl {}

impl Crypto for CryptoImpl {
    type Error = TokenserverError;

    /// Expands `info` into a 32-byte key via HKDF-SHA256 over `secret` and
    /// the optional `salt`.
    fn hkdf(&self, secret: &str, salt: Option<&[u8]>, info: &[u8]) -> Result<Vec<u8>, Self::Error> {
        let kdf = Hkdf::<Sha256>::new(salt, secret.as_bytes());
        let mut derived = [0u8; SHA256_OUTPUT_LEN];
        kdf.expand(info, &mut derived)
            .map_err(|_| TokenserverError::internal_error())?;
        Ok(derived.to_vec())
    }

    /// Computes an HMAC-SHA256 tag over `payload` with `key`.
    fn hmac_sign(&self, key: &[u8], payload: &[u8]) -> Result<Vec<u8>, Self::Error> {
        let mut signer: Hmac<Sha256> =
            Hmac::new_from_slice(key).map_err(|_| TokenserverError::internal_error())?;
        signer.update(payload);
        let tag = signer.finalize().into_bytes();
        Ok(tag.to_vec())
    }

    /// Checks an HMAC-SHA256 `signature` over `payload` with the shared `key`.
    fn hmac_verify(&self, key: &[u8], payload: &[u8], signature: &[u8]) -> Result<(), Self::Error> {
        let mut verifier: Hmac<Sha256> =
            Hmac::new_from_slice(key).map_err(|_| TokenserverError::internal_error())?;
        verifier.update(payload);
        verifier
            .verify_slice(signature)
            .map_err(|_| TokenserverError::internal_error())
    }

    /// Fills `output` with bytes from ring's cryptographically secure RNG.
    fn rand_bytes(&self, output: &mut [u8]) -> Result<(), Self::Error> {
        SystemRandom::new()
            .fill(output)
            .map_err(|_| TokenserverError::internal_error())?;
        Ok(())
    }
}
/// OAuthVerifyError captures the errors possible while verifying an OAuth JWT access token
#[derive(Debug, thiserror::Error)]
pub enum OAuthVerifyError {
    /// The token decoded but its signature is no longer valid (expired).
    #[error("The signature has expired")]
    ExpiredSignature,
    /// The token decoded but is not trusted (e.g. an unacceptable `typ` header).
    #[error("Untrusted token")]
    TrustError,
    /// The verification key itself could not be used.
    #[error("Invalid Key")]
    InvalidKey,
    /// The token could not be decoded as a JWT at all.
    #[error("Error decoding JWT")]
    DecodingError,
    #[error("The key was well formatted, but the signature was invalid")]
    InvalidSignature,
}
impl OAuthVerifyError {
    /// Returns the statsd label emitted for this error variant.
    pub fn metric_label(&self) -> &'static str {
        match self {
            Self::DecodingError => "oauth.error.decoding_error",
            Self::ExpiredSignature => "oauth.error.expired_signature",
            Self::InvalidKey => "oauth.error.invalid_key",
            Self::InvalidSignature => "oauth.error.invalid_signature",
            Self::TrustError => "oauth.error.trust_error",
        }
    }

    /// Whether this error points at a server-side problem worth reporting,
    /// as opposed to a routine client-side verification failure.
    pub fn is_reportable_err(&self) -> bool {
        matches!(self, Self::DecodingError | Self::InvalidKey)
    }
}
impl From<jsonwebtoken::errors::Error> for OAuthVerifyError {
fn from(value: jsonwebtoken::errors::Error) -> Self {
match value.kind() {
ErrorKind::InvalidKeyFormat => OAuthVerifyError::InvalidKey,
ErrorKind::InvalidSignature => OAuthVerifyError::InvalidSignature,
ErrorKind::ExpiredSignature => OAuthVerifyError::ExpiredSignature,
_ => OAuthVerifyError::DecodingError,
}
}
}
/// A trait representing a JSON Web Token verifier <https://datatracker.ietf.org/doc/html/rfc7519>
pub trait JWTVerifier: TryFrom<Self::Key, Error = OAuthVerifyError> + Sync + Send + Clone {
    /// The key material a verifier is constructed from (via the `TryFrom` bound).
    type Key: DeserializeOwned;

    /// Verifies `token` and deserializes its claims into `T`.
    fn verify<T: DeserializeOwned>(&self, token: &str) -> Result<T, OAuthVerifyError>;
}
/// An implementation of the JWT verifier using the jsonwebtoken crate
#[derive(Clone)]
pub struct JWTVerifierImpl {
    // Decoding key built from a JWK (see the `TryFrom<Jwk>` impl below).
    key: DecodingKey,
    // Validation parameters applied on every `verify` call.
    validation: Validation,
}
impl JWTVerifier for JWTVerifierImpl {
    type Key = Jwk;

    /// Decodes and validates `token` against the configured key and
    /// validation rules, then checks the `typ` header before trusting
    /// the claims.
    fn verify<T: DeserializeOwned>(&self, token: &str) -> Result<T, OAuthVerifyError> {
        // Signature + claims validation; any jsonwebtoken error converts
        // into an `OAuthVerifyError` via the `From` impl in this module.
        let token_data = jsonwebtoken::decode::<T>(token, &self.key, &self.validation)?;
        // A token without a `typ` header cannot be proven to be an access
        // token, so it is rejected as untrusted.
        token_data
            .header
            .typ
            .ok_or(OAuthVerifyError::TrustError)
            .and_then(|typ| {
                // Ref https://tools.ietf.org/html/rfc7515#section-4.1.9 the `typ` header
                // is lowercase and has an implicit default `application/` prefix.
                let typ = if !typ.contains('/') {
                    format!("application/{}", typ)
                } else {
                    typ
                };
                // Only OAuth access tokens (`at+jwt`) are accepted.
                if typ.to_lowercase() != "application/at+jwt" {
                    return Err(OAuthVerifyError::TrustError);
                }
                Ok(typ)
            })?;
        Ok(token_data.claims)
    }
}
impl TryFrom<Jwk> for JWTVerifierImpl {
    type Error = OAuthVerifyError;

    /// Builds a verifier from a JWK, configuring RS256 validation.
    /// Fails with `InvalidKey` when the JWK cannot be turned into a
    /// decoding key.
    fn try_from(value: Jwk) -> Result<Self, Self::Error> {
        let decoding_key =
            DecodingKey::from_jwk(&value).map_err(|_| OAuthVerifyError::InvalidKey)?;
        let mut validation = Validation::new(Algorithm::RS256);
        // The FxA OAuth ecosystem currently doesn't make good use of aud, and
        // instead relies on scope for restricting which services can accept
        // which tokens. So there's no value in checking it here, and in fact if
        // we check it here, it fails because the right audience isn't being
        // requested.
        validation.validate_aud = false;
        Ok(Self {
            key: decoding_key,
            validation,
        })
    }
}

View File

@ -1,17 +1,21 @@
pub mod browserid;
#[cfg(not(feature = "py"))]
mod crypto;
#[cfg(not(feature = "py"))]
pub use crypto::{JWTVerifier, JWTVerifierImpl};
pub mod oauth;
mod token;
use syncserver_common::Metrics;
pub use token::Tokenlib;
use std::fmt;
use async_trait::async_trait;
use dyn_clone::{self, DynClone};
use pyo3::{
prelude::{IntoPy, PyErr, PyModule, PyObject, Python},
types::IntoPyDict,
};
use serde::{Deserialize, Serialize};
use tokenserver_common::TokenserverError;
/// Represents the origin of the token used by Sync clients to access their data.
#[derive(Clone, Copy, Default, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
@ -33,7 +37,7 @@ impl fmt::Display for TokenserverOrigin {
}
/// The plaintext needed to build a token.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
pub struct MakeTokenPlaintext {
pub node: String,
pub fxa_kid: String,
@ -45,69 +49,6 @@ pub struct MakeTokenPlaintext {
pub tokenserver_origin: TokenserverOrigin,
}
impl IntoPy<PyObject> for MakeTokenPlaintext {
fn into_py(self, py: Python<'_>) -> PyObject {
let dict = [
("node", self.node),
("fxa_kid", self.fxa_kid),
("fxa_uid", self.fxa_uid),
("hashed_device_id", self.hashed_device_id),
("hashed_fxa_uid", self.hashed_fxa_uid),
("tokenserver_origin", self.tokenserver_origin.to_string()),
]
.into_py_dict(py);
// These need to be set separately since they aren't strings, and
// Rust doesn't support heterogeneous arrays
dict.set_item("expires", self.expires).unwrap();
dict.set_item("uid", self.uid).unwrap();
dict.into()
}
}
/// An adapter to the tokenlib Python library.
pub struct Tokenlib;
impl Tokenlib {
/// Builds the token and derived secret to be returned by Tokenserver.
pub fn get_token_and_derived_secret(
plaintext: MakeTokenPlaintext,
shared_secret: &str,
) -> Result<(String, String), TokenserverError> {
Python::with_gil(|py| {
// `import tokenlib`
let module = PyModule::import(py, "tokenlib").map_err(|e| {
e.print_and_set_sys_last_vars(py);
e
})?;
// `kwargs = { 'secret': shared_secret }`
let kwargs = [("secret", shared_secret)].into_py_dict(py);
// `token = tokenlib.make_token(plaintext, **kwargs)`
let token = module
.getattr("make_token")?
.call((plaintext,), Some(kwargs))
.map_err(|e| {
e.print_and_set_sys_last_vars(py);
e
})
.and_then(|x| x.extract())?;
// `derived_secret = tokenlib.get_derived_secret(token, **kwargs)`
let derived_secret = module
.getattr("get_derived_secret")?
.call((&token,), Some(kwargs))
.map_err(|e| {
e.print_and_set_sys_last_vars(py);
e
})
.and_then(|x| x.extract())?;
// `return (token, derived_secret)`
Ok((token, derived_secret))
})
.map_err(pyerr_to_tokenserver_error)
}
}
/// Implementers of this trait can be used to verify tokens for Tokenserver.
#[async_trait]
pub trait VerifyToken: DynClone + Sync + Send {
@ -115,7 +56,11 @@ pub trait VerifyToken: DynClone + Sync + Send {
/// Verifies the given token. This function is async because token verification often involves
/// making a request to a remote server.
async fn verify(&self, token: String) -> Result<Self::Output, TokenserverError>;
async fn verify(
&self,
token: String,
metrics: &Metrics,
) -> Result<Self::Output, TokenserverError>;
}
dyn_clone::clone_trait_object!(<T> VerifyToken<Output=T>);
@ -131,16 +76,9 @@ pub struct MockVerifier<T: Clone + Send + Sync> {
impl<T: Clone + Send + Sync> VerifyToken for MockVerifier<T> {
type Output = T;
async fn verify(&self, _token: String) -> Result<T, TokenserverError> {
async fn verify(&self, _token: String, _metrics: &Metrics) -> Result<T, TokenserverError> {
self.valid
.then(|| self.verify_output.clone())
.ok_or_else(|| TokenserverError::invalid_credentials("Unauthorized".to_owned()))
}
}
fn pyerr_to_tokenserver_error(e: PyErr) -> TokenserverError {
TokenserverError {
context: e.to_string(),
..TokenserverError::internal_error()
}
}

View File

@ -1,18 +1,15 @@
use async_trait::async_trait;
use pyo3::{
prelude::{Py, PyAny, PyErr, PyModule, Python},
types::{IntoPyDict, PyString},
};
use serde::{Deserialize, Serialize};
use serde_json;
use syncserver_common::BlockingThreadpool;
use tokenserver_common::TokenserverError;
use tokenserver_settings::{Jwk, Settings};
use tokio::time;
use super::VerifyToken;
#[cfg(not(feature = "py"))]
mod native;
#[cfg(feature = "py")]
mod py;
use std::{sync::Arc, time::Duration};
#[cfg(feature = "py")]
pub type Verifier = py::Verifier;
#[cfg(not(feature = "py"))]
pub type Verifier<J> = native::Verifier<J>;
/// The information extracted from a valid OAuth token.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
@ -21,148 +18,3 @@ pub struct VerifyOutput {
pub fxa_uid: String,
pub generation: Option<i64>,
}
/// The verifier used to verify OAuth tokens.
#[derive(Clone)]
pub struct Verifier {
// Note that we do not need to use an Arc here, since Py is already a reference-counted
// pointer
inner: Py<PyAny>,
timeout: u64,
blocking_threadpool: Arc<BlockingThreadpool>,
}
impl Verifier {
const FILENAME: &'static str = "verify.py";
pub fn new(
settings: &Settings,
blocking_threadpool: Arc<BlockingThreadpool>,
) -> Result<Self, TokenserverError> {
let inner: Py<PyAny> = Python::with_gil::<_, Result<Py<PyAny>, PyErr>>(|py| {
let code = include_str!("verify.py");
let module = PyModule::from_code(py, code, Self::FILENAME, Self::FILENAME)?;
let kwargs = {
let dict = [("server_url", &settings.fxa_oauth_server_url)].into_py_dict(py);
let parse_jwk = |jwk: &Jwk| {
let dict = [
("kty", &jwk.kty),
("alg", &jwk.alg),
("kid", &jwk.kid),
("use", &jwk.use_of_key),
("n", &jwk.n),
("e", &jwk.e),
]
.into_py_dict(py);
dict.set_item("fxa-createdAt", jwk.fxa_created_at).unwrap();
dict
};
let jwks = match (
&settings.fxa_oauth_primary_jwk,
&settings.fxa_oauth_secondary_jwk,
) {
(Some(primary_jwk), Some(secondary_jwk)) => {
Some(vec![parse_jwk(primary_jwk), parse_jwk(secondary_jwk)])
}
(Some(jwk), None) | (None, Some(jwk)) => Some(vec![parse_jwk(jwk)]),
(None, None) => None,
};
dict.set_item("jwks", jwks).unwrap();
dict
};
let object: Py<PyAny> = module
.getattr("FxaOAuthClient")?
.call((), Some(kwargs))
.map_err(|e| {
e.print_and_set_sys_last_vars(py);
e
})?
.into();
Ok(object)
})
.map_err(super::pyerr_to_tokenserver_error)?;
Ok(Self {
inner,
timeout: settings.fxa_oauth_request_timeout,
blocking_threadpool,
})
}
}
#[async_trait]
impl VerifyToken for Verifier {
    type Output = VerifyOutput;

    /// Verifies an OAuth token. Returns `VerifyOutput` for valid tokens and a `TokenserverError`
    /// for invalid tokens.
    async fn verify(&self, token: String) -> Result<VerifyOutput, TokenserverError> {
        // We don't want to move `self` into the body of the closure here because we'd need to
        // clone it. Cloning it is only necessary if we need to verify the token remotely via FxA,
        // since that would require passing `self` to a separate thread. Passing &Self to a closure
        // gives us the flexibility to clone only when necessary.
        let verify_inner = |verifier: &Self| {
            let maybe_verify_output_string = Python::with_gil(|py| {
                let client = verifier.inner.as_ref(py);
                // `client.verify_token(token)`
                let result: &PyAny = client
                    .getattr("verify_token")?
                    .call((token,), None)
                    .map_err(|e| {
                        // Print the Python traceback before propagating the error.
                        e.print_and_set_sys_last_vars(py);
                        e
                    })?;
                // The Python side returns `None` for a token it could not verify, and a
                // JSON string describing the verification result otherwise.
                if result.is_none() {
                    Ok(None)
                } else {
                    let verify_output_python_string = result.downcast::<PyString>()?;
                    verify_output_python_string.extract::<String>().map(Some)
                }
            })
            .map_err(|e| TokenserverError {
                context: format!("pyo3 error in OAuth verifier: {}", e),
                ..TokenserverError::invalid_credentials("Unauthorized".to_owned())
            })?;
            // Deserialize the JSON payload into `VerifyOutput`; `None` means the token
            // was rejected.
            match maybe_verify_output_string {
                Some(verify_output_string) => {
                    serde_json::from_str::<VerifyOutput>(&verify_output_string).map_err(|e| {
                        TokenserverError {
                            context: format!("Invalid OAuth verify output: {}", e),
                            ..TokenserverError::invalid_credentials("Unauthorized".to_owned())
                        }
                    })
                }
                None => Err(TokenserverError {
                    context: "Invalid OAuth token".to_owned(),
                    ..TokenserverError::invalid_credentials("Unauthorized".to_owned())
                }),
            }
        };
        let verifier = self.clone();
        // If the JWK is not cached or if the token is not a JWT/wasn't signed by a known key
        // type, PyFxA will make a request to FxA to retrieve it, blocking this thread. To
        // improve performance, we make the request on a thread in a threadpool specifically
        // used for blocking operations. The JWK should _always_ be cached in production to
        // maximize performance.
        let fut = self
            .blocking_threadpool
            .spawn(move || verify_inner(&verifier));
        // The PyFxA OAuth client does not offer a way to set a request timeout, so we set one here
        // by timing out the future if the verification process blocks this thread for longer
        // than the specified number of seconds.
        time::timeout(Duration::from_secs(self.timeout), fut)
            .await
            .map_err(|_| TokenserverError {
                context: "OAuth verification timeout".to_owned(),
                ..TokenserverError::resource_unavailable()
            })?
    }
}

View File

@ -0,0 +1,557 @@
use super::VerifyOutput;
pub use crate::crypto::JWTVerifier;
use crate::crypto::OAuthVerifyError;
use crate::VerifyToken;
use async_trait::async_trait;
use reqwest::Url;
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, time::Duration};
use syncserver_common::Metrics;
use tokenserver_common::TokenserverError;
use tokenserver_settings::Settings;
const SYNC_SCOPE: &str = "https://identity.mozilla.com/apps/oldsync";
/// The JWT claims Tokenserver needs from an FxA OAuth access token.
#[derive(Serialize, Deserialize, Debug)]
struct TokenClaims {
    // The user's FxA UID (the standard JWT `sub` claim).
    #[serde(rename = "sub")]
    user: String,
    // The token's granted scopes as a single string; `validate` scans it for the
    // sync scope. The remote-verify path joins FxA's scope list with ','.
    scope: String,
    // The user's FxA generation number, when FxA includes it.
    #[serde(rename = "fxa-generation")]
    generation: Option<i64>,
}
impl TokenClaims {
    /// Checks that the token grants the sync scope, converting the claims into
    /// `VerifyOutput` on success and returning a 401 `invalid-credentials` error otherwise.
    fn validate(self) -> Result<VerifyOutput, TokenserverError> {
        // The remote /v1/verify path joins FxA's scope list with ',', while OAuth/JWT
        // `scope` claims are conventionally space-delimited (RFC 6749 §3.3). Accept
        // either delimiter so locally-verified multi-scope JWTs are handled too; scope
        // values themselves contain neither spaces nor commas, so this cannot produce
        // a false match.
        if !self
            .scope
            .split(|c| c == ',' || c == ' ')
            .any(|scope| scope == SYNC_SCOPE)
        {
            return Err(TokenserverError::invalid_credentials(
                "Unauthorized".to_string(),
            ));
        }

        Ok(self.into())
    }
}
impl From<TokenClaims> for VerifyOutput {
fn from(value: TokenClaims) -> Self {
Self {
fxa_uid: value.user,
generation: value.generation,
}
}
}
/// The verifier used to verify OAuth tokens.
#[derive(Clone)]
pub struct Verifier<J> {
    // FxA's /v1/verify endpoint, used as a fallback when local verification is inconclusive.
    verify_url: Url,
    // FxA's /v1/jwks endpoint, queried when no JWK verifiers were configured locally.
    jwks_url: Url,
    // JWT verifiers built from locally-configured JWKs; may be empty.
    jwk_verifiers: Vec<J>,
    // Shared HTTP client, reused for all requests made by this verifier.
    http_client: reqwest::Client,
}
impl<J> Verifier<J>
where
    J: JWTVerifier,
{
    /// Creates a verifier for the FxA server named in `settings`, with an optional set of
    /// locally-configured JWK verifiers. If `jwk_verifiers` is empty, keys are fetched
    /// from FxA's /v1/jwks endpoint at verification time.
    ///
    /// Uses `internal_err_with_ctx` (rather than a bare `internal_error`) so that URL and
    /// client-construction failures carry their cause in the error context.
    pub fn new(settings: &Settings, jwk_verifiers: Vec<J>) -> Result<Self, TokenserverError> {
        let base_url =
            Url::parse(&settings.fxa_oauth_server_url).map_err(internal_err_with_ctx)?;
        let verify_url = base_url.join("v1/verify").map_err(internal_err_with_ctx)?;
        let jwks_url = base_url.join("v1/jwks").map_err(internal_err_with_ctx)?;
        // One client for all requests: reqwest clients pool connections internally, so
        // sharing it avoids per-request setup costs.
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(settings.fxa_oauth_request_timeout))
            .use_rustls_tls()
            .build()
            .map_err(internal_err_with_ctx)?;

        Ok(Self {
            verify_url,
            jwks_url,
            jwk_verifiers,
            http_client,
        })
    }

    /// Asks FxA's /v1/verify endpoint to verify `token` remotely. Any transport failure
    /// or non-2xx response is reported as invalid credentials with context attached.
    async fn remote_verify_token(&self, token: &str) -> Result<TokenClaims, TokenserverError> {
        #[derive(Serialize)]
        struct VerifyRequest<'a> {
            token: &'a str,
        }

        #[derive(Serialize, Deserialize)]
        struct VerifyResponse {
            user: String,
            scope: Vec<String>,
            generation: Option<i64>,
        }

        impl From<VerifyResponse> for TokenClaims {
            fn from(value: VerifyResponse) -> Self {
                Self {
                    user: value.user,
                    // The claims type stores scopes as a single joined string.
                    scope: value.scope.join(","),
                    generation: value.generation,
                }
            }
        }

        Ok(self
            .http_client
            .post(self.verify_url.clone())
            .json(&VerifyRequest { token })
            .send()
            .await
            .map_err(unauthorized_err_with_ctx)
            .and_then(|res| {
                if !res.status().is_success() {
                    Err(unauthorized_err_with_ctx(format!(
                        "Got verify status code: {}",
                        res.status()
                    )))
                } else {
                    Ok(res)
                }
            })?
            .json::<VerifyResponse>()
            .await
            .map_err(unauthorized_err_with_ctx)?
            .into())
    }

    /// Fetches FxA's current public keys from the /v1/jwks endpoint and converts each
    /// returned JWK into a `J` verifier.
    async fn get_remote_jwks(&self) -> Result<Vec<J>, TokenserverError> {
        #[derive(Deserialize)]
        struct KeysResponse<K> {
            keys: Vec<K>,
        }

        self.http_client
            .get(self.jwks_url.clone())
            .send()
            .await
            .map_err(internal_err_with_ctx)?
            .json::<KeysResponse<J::Key>>()
            .await
            .map_err(internal_err_with_ctx)?
            .keys
            .into_iter()
            .map(|key| key.try_into().map_err(internal_err_with_ctx))
            .collect()
    }

    /// Tries each verifier in turn until one either verifies the token or fails with
    /// something other than an invalid signature.
    fn verify_jwt_locally(
        &self,
        verifiers: &[Cow<'_, J>],
        token: &str,
    ) -> Result<TokenClaims, OAuthVerifyError> {
        if verifiers.is_empty() {
            return Err(OAuthVerifyError::InvalidKey);
        }
        verifiers
            .iter()
            .find_map(|verifier| {
                match verifier.verify::<TokenClaims>(token) {
                    // If it's an invalid signature, it means our key was well formatted,
                    // but the signature was incorrect. Lets try another key if we have any
                    Err(OAuthVerifyError::InvalidSignature) => None,
                    res => Some(res),
                }
            })
            // If there is nothing, it means all of our keys were well formatted, but none of them
            // were able to verify the signature, so let's return a TrustError
            .ok_or(OAuthVerifyError::TrustError)?
    }
}
#[async_trait]
impl<J> VerifyToken for Verifier<J>
where
    J: JWTVerifier,
{
    type Output = VerifyOutput;

    /// Verifies an OAuth token. Returns `VerifyOutput` for valid tokens and a `TokenserverError`
    /// for invalid tokens.
    ///
    /// The verifier will first attempt to verify the token using FxA's public keys, which were
    /// provided as environment variables.
    ///
    /// If FxA's public keys were not supplied, then the verifier will query FxA's /v1/jwks
    /// endpoint to get the latest public keys.
    ///
    /// If verifying the tokens fails because the keys are invalid, or because the keys were
    /// valid but the tokens have changed their structure, then the verifier will fall back to
    /// hitting FxA's /v1/verify endpoint to verify instead. All other failures will be
    /// recorded as invalid credentials and will return a generic "Unauthorized" message to
    /// the user.
    async fn verify(
        &self,
        token: String,
        metrics: &Metrics,
    ) -> Result<VerifyOutput, TokenserverError> {
        // Borrow the locally-configured verifiers when we have them...
        let mut verifiers = self
            .jwk_verifiers
            .iter()
            .map(Cow::Borrowed)
            .collect::<Vec<_>>();
        // ...otherwise fetch the current keys from FxA. A fetch failure is only logged:
        // the empty list makes local verification fail with `InvalidKey`, which triggers
        // the remote /v1/verify fallback below.
        if self.jwk_verifiers.is_empty() {
            verifiers = self
                .get_remote_jwks()
                .await
                .unwrap_or_else(|e| {
                    slog_scope::warn!("Error requesting remote jwks: {}", e);
                    vec![]
                })
                .into_iter()
                .map(Cow::Owned)
                .collect();
        }
        let claims = match self.verify_jwt_locally(&verifiers, &token) {
            Ok(res) => res,
            Err(e) => {
                // Record reportable verification failures for monitoring.
                if e.is_reportable_err() {
                    metrics.incr(e.metric_label())
                }
                match e {
                    // These cases are inconclusive rather than definitive rejections,
                    // so ask FxA directly.
                    OAuthVerifyError::DecodingError | OAuthVerifyError::InvalidKey => {
                        self.remote_verify_token(&token).await?
                    }
                    e => return Err(unauthorized_err_with_ctx(e)),
                }
            }
        };
        // Finally, check that the token actually grants the sync scope.
        claims.validate()
    }
}
fn unauthorized_err_with_ctx<E: std::fmt::Display>(err: E) -> TokenserverError {
TokenserverError {
context: err.to_string(),
..TokenserverError::invalid_credentials("Unauthorized".to_string())
}
}
fn internal_err_with_ctx<E: std::fmt::Display>(err: E) -> TokenserverError {
TokenserverError {
context: err.to_string(),
..TokenserverError::internal_error()
}
}
#[cfg(test)]
mod tests {
    use crate::crypto::{JWTVerifierImpl, OAuthVerifyError};
    use serde_json::json;

    use super::*;

    #[derive(Deserialize)]
    struct MockJWK {}

    /// Generates a `MockJWTVerifier` type whose `verify` body is the given expression.
    /// The two-argument form additionally binds the token parameter to `$token` so the
    /// body can inspect it.
    macro_rules! mock_jwk_verifier {
        ($im:expr) => {
            mock_jwk_verifier!(_token, $im);
        };
        ($token:ident, $im:expr) => {
            #[derive(Clone, Debug)]
            struct MockJWTVerifier {}

            impl TryFrom<MockJWK> for MockJWTVerifier {
                type Error = OAuthVerifyError;

                fn try_from(_value: MockJWK) -> Result<Self, Self::Error> {
                    Ok(Self {})
                }
            }

            impl JWTVerifier for MockJWTVerifier {
                type Key = MockJWK;

                fn verify<T: ::serde::de::DeserializeOwned>(
                    &self,
                    $token: &str,
                ) -> Result<T, OAuthVerifyError> {
                    $im
                }
            }
        };
    }

    // NOTE: hit-count expectations are set with `.expect(n)` *before* `.create()` and
    // checked with `mock.assert()` after exercising the verifier. Calling `expect` on an
    // already-created mock (as the previous version of these tests did) records nothing
    // and asserts nothing.

    #[tokio::test]
    async fn test_no_keys_in_verifier_fallsback_to_fxa() -> Result<(), TokenserverError> {
        // With no local keys and a failing /v1/jwks endpoint, the verifier must fall
        // back to FxA's /v1/verify endpoint.
        let mock_jwks = mockito::mock("GET", "/v1/jwks")
            .with_status(500)
            .expect(1)
            .create();
        let body = json!({
            "user": "fxa_id",
            "scope": [SYNC_SCOPE],
            "generation": 123
        });
        let mock_verify = mockito::mock("POST", "/v1/verify")
            .with_header("content-type", "application/json")
            .with_status(200)
            .with_body(body.to_string())
            .expect(1)
            .create();
        let settings = Settings {
            fxa_oauth_server_url: mockito::server_url(),
            ..Default::default()
        };
        let verifier: Verifier<JWTVerifierImpl> = Verifier::new(&settings, vec![])?;
        let res = verifier
            .verify("a token fxa will validate".to_string(), &Default::default())
            .await?;
        mock_jwks.assert();
        mock_verify.assert();
        assert_eq!(res.generation.unwrap(), 123);
        assert_eq!(res.fxa_uid, "fxa_id");
        Ok(())
    }

    #[tokio::test]
    async fn test_expired_signature_fails() -> Result<(), TokenserverError> {
        let mock = mockito::mock("POST", "/v1/verify").expect(0).create();
        mock_jwk_verifier!(Err(OAuthVerifyError::InvalidSignature));

        let jwk_verifiers = vec![MockJWTVerifier {}];
        let settings = Settings {
            fxa_oauth_server_url: mockito::server_url(),
            ..Settings::default()
        };
        let verifier: Verifier<MockJWTVerifier> = Verifier::new(&settings, jwk_verifiers)?;
        let err = verifier
            .verify("An expired token".to_string(), &Default::default())
            .await
            .unwrap_err();
        // We also make sure we didn't try to hit the server
        mock.assert();
        assert_eq!(err.status, "invalid-credentials");
        assert_eq!(err.http_status, 401);
        assert_eq!(err.description, "Unauthorized");
        Ok(())
    }

    #[tokio::test]
    async fn test_verifier_attempts_all_keys_if_invalid_signature() -> Result<(), TokenserverError>
    {
        let mock = mockito::mock("POST", "/v1/verify").expect(0).create();

        // The first key always rejects the signature; the second key "verifies" by
        // parsing the token as JSON claims. The verifier must try both.
        #[derive(Debug, Clone)]
        struct MockJWTVerifier {
            id: u8,
        }

        impl TryFrom<MockJWK> for MockJWTVerifier {
            type Error = OAuthVerifyError;

            fn try_from(_value: MockJWK) -> Result<Self, Self::Error> {
                Ok(Self { id: 0 })
            }
        }

        impl JWTVerifier for MockJWTVerifier {
            type Key = MockJWK;

            fn verify<T: serde::de::DeserializeOwned>(
                &self,
                token: &str,
            ) -> Result<T, OAuthVerifyError> {
                if self.id == 0 {
                    Err(OAuthVerifyError::InvalidSignature)
                } else {
                    Ok(serde_json::from_str(token).unwrap())
                }
            }
        }

        let jwk_verifiers = vec![MockJWTVerifier { id: 0 }, MockJWTVerifier { id: 1 }];
        let settings = Settings {
            fxa_oauth_server_url: mockito::server_url(),
            ..Settings::default()
        };
        let verifier: Verifier<MockJWTVerifier> = Verifier::new(&settings, jwk_verifiers).unwrap();
        let token_claims = TokenClaims {
            user: "fxa_id".to_string(),
            scope: SYNC_SCOPE.to_string(),
            generation: Some(124),
        };
        let res = verifier
            .verify(
                serde_json::to_string(&token_claims).unwrap(),
                &Default::default(),
            )
            .await?;
        assert_eq!(res.fxa_uid, "fxa_id");
        assert_eq!(res.generation.unwrap(), 124);
        mock.assert(); // We shouldn't have hit the server
        Ok(())
    }

    #[tokio::test]
    async fn test_verifier_all_signature_failures_fails() -> Result<(), TokenserverError> {
        let mock_verify = mockito::mock("POST", "/v1/verify").expect(0).create();
        mock_jwk_verifier!(Err(OAuthVerifyError::InvalidSignature));

        let jwk_verifiers = vec![MockJWTVerifier {}, MockJWTVerifier {}];
        let settings = Settings {
            fxa_oauth_server_url: mockito::server_url(),
            ..Settings::default()
        };
        let verifier: Verifier<MockJWTVerifier> = Verifier::new(&settings, jwk_verifiers).unwrap();
        let err = verifier
            .verify(
                "a token with an invalid signature".to_string(),
                &Default::default(),
            )
            .await
            .unwrap_err();
        assert_eq!(err.status, "invalid-credentials");
        assert_eq!(err.http_status, 401);
        assert_eq!(err.description, "Unauthorized");
        mock_verify.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_verifier_fallsback_if_decode_error() -> Result<(), TokenserverError> {
        let body = json!({
            "user": "fxa_id",
            "scope": [SYNC_SCOPE],
            "generation": 123
        });
        let mock_verify = mockito::mock("POST", "/v1/verify")
            .with_header("content-type", "application/json")
            .with_status(200)
            .with_body(body.to_string())
            .expect(1)
            .create();
        mock_jwk_verifier!(Err(OAuthVerifyError::DecodingError));

        let jwk_verifiers = vec![MockJWTVerifier {}];
        let settings = Settings {
            fxa_oauth_server_url: mockito::server_url(),
            ..Settings::default()
        };
        let verifier: Verifier<MockJWTVerifier> = Verifier::new(&settings, jwk_verifiers).unwrap();
        let res = verifier
            .verify(
                "invalid token that can't be decoded".to_string(),
                &Default::default(),
            )
            .await?;
        assert_eq!(res.fxa_uid, "fxa_id");
        assert_eq!(res.generation.unwrap(), 123);
        mock_verify.assert(); // We should have hit the server
        Ok(())
    }

    #[tokio::test]
    async fn test_no_sync_scope_fails() -> Result<(), TokenserverError> {
        // A token that verifies locally but lacks the sync scope must be rejected.
        let token_claims = TokenClaims {
            user: "fxa_id".to_string(),
            scope: "some other scope".to_string(),
            generation: Some(124),
        };
        mock_jwk_verifier!(token, Ok(serde_json::from_str(token).unwrap()));

        let jwk_verifiers = vec![MockJWTVerifier {}];
        let settings = Settings {
            fxa_oauth_server_url: mockito::server_url(),
            ..Settings::default()
        };
        let verifier: Verifier<MockJWTVerifier> = Verifier::new(&settings, jwk_verifiers).unwrap();
        let err = verifier
            .verify(
                serde_json::to_string(&token_claims).unwrap(),
                &Default::default(),
            )
            .await
            .unwrap_err();
        assert_eq!(err.status, "invalid-credentials");
        assert_eq!(err.http_status, 401);
        assert_eq!(err.description, "Unauthorized");
        Ok(())
    }

    #[tokio::test]
    async fn test_fxa_rejects_token_no_matter_the_body() -> Result<(), TokenserverError> {
        let body = json!({
            "user": "fxa_id",
            "scope": [SYNC_SCOPE],
            "generation": 123
        });
        let mock_verify = mockito::mock("POST", "/v1/verify")
            .with_header("content-type", "application/json")
            .with_status(401)
            // Even though the body is fine, if FxA returns a non-200, we automatically
            // return a credential error
            .with_body(body.to_string())
            .expect(1)
            .create();
        let settings = Settings {
            fxa_oauth_server_url: mockito::server_url(),
            ..Settings::default()
        };
        mock_jwk_verifier!(Err(OAuthVerifyError::DecodingError));

        let jwk_verifiers = vec![];
        let verifier: Verifier<MockJWTVerifier> = Verifier::new(&settings, jwk_verifiers).unwrap();
        let err = verifier
            .verify(
                "A token that we will ask FxA about".to_string(),
                &Default::default(),
            )
            .await
            .unwrap_err();
        assert_eq!(err.status, "invalid-credentials");
        assert_eq!(err.http_status, 401);
        assert_eq!(err.description, "Unauthorized");
        mock_verify.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_fxa_accepts_token_but_bad_body() -> Result<(), TokenserverError> {
        let body = json!({
            "bad_key": "foo",
            "scope": [SYNC_SCOPE],
            "bad_genreation": 123
        });
        let mock_verify = mockito::mock("POST", "/v1/verify")
            .with_header("content-type", "application/json")
            .with_status(200)
            // Even though the body is valid json, it doesn't match our expectation so we'll error
            // out
            .with_body(body.to_string())
            .expect(1)
            .create();
        let settings = Settings {
            fxa_oauth_server_url: mockito::server_url(),
            ..Settings::default()
        };
        mock_jwk_verifier!(Err(OAuthVerifyError::DecodingError));

        let jwk_verifiers = vec![];
        let verifier: Verifier<MockJWTVerifier> = Verifier::new(&settings, jwk_verifiers).unwrap();
        let err = verifier
            .verify(
                "A token that we will ask FxA about".to_string(),
                &Default::default(),
            )
            .await
            .unwrap_err();
        assert_eq!(err.status, "invalid-credentials");
        assert_eq!(err.http_status, 401);
        assert_eq!(err.description, "Unauthorized");
        mock_verify.assert();
        Ok(())
    }
}

View File

@ -0,0 +1,195 @@
use async_trait::async_trait;
use jsonwebtoken::jwk::{AlgorithmParameters, Jwk, PublicKeyUse, RSAKeyParameters};
use pyo3::{
prelude::{Py, PyAny, PyErr, PyModule, Python},
types::{IntoPyDict, PyString},
};
use serde_json;
use syncserver_common::{BlockingThreadpool, Metrics};
use tokenserver_common::TokenserverError;
use tokenserver_settings::Settings;
use tokio::time;
use super::VerifyOutput;
use crate::VerifyToken;
use std::{sync::Arc, time::Duration};
/// The verifier used to verify OAuth tokens.
#[derive(Clone)]
pub struct Verifier {
    // Note that we do not need to use an Arc here, since Py is already a reference-counted
    // pointer
    inner: Py<PyAny>,
    // Request timeout in seconds, enforced by timing out the verification future.
    timeout: u64,
    // Threadpool on which the (blocking) PyFxA verification runs.
    blocking_threadpool: Arc<BlockingThreadpool>,
}
impl Verifier {
    // Name under which the bundled Python module is registered.
    const FILENAME: &'static str = "verify.py";

    /// Builds the Python-backed verifier: loads the bundled `verify.py`, converts any
    /// locally-configured JWKs into the dict shape the Python client expects, and
    /// instantiates `FxaOAuthClient`.
    pub fn new(
        settings: &Settings,
        blocking_threadpool: Arc<BlockingThreadpool>,
    ) -> Result<Self, TokenserverError> {
        let inner: Py<PyAny> = Python::with_gil::<_, Result<Py<PyAny>, TokenserverError>>(|py| {
            let code = include_str!("verify.py");
            let module = PyModule::from_code(py, code, Self::FILENAME, Self::FILENAME)
                .map_err(pyerr_to_tokenserver_error)?;
            let kwargs = {
                let dict = [("server_url", &settings.fxa_oauth_server_url)].into_py_dict(py);
                // Converts one JWK into a Python dict. Only RSA signing keys are
                // supported; anything else is treated as a configuration error.
                let parse_jwk = |jwk: &Jwk| {
                    let (n, e) = match &jwk.algorithm {
                        AlgorithmParameters::RSA(RSAKeyParameters { key_type: _, n, e }) => (n, e),
                        _ => return Err(TokenserverError::internal_error()),
                    };
                    let alg = jwk
                        .common
                        .key_algorithm
                        .ok_or_else(TokenserverError::internal_error)?
                        .to_string();
                    let kid = jwk
                        .common
                        .key_id
                        .as_ref()
                        .ok_or_else(TokenserverError::internal_error)?;
                    // The configured key must be marked for signature use.
                    if !matches!(
                        jwk.common
                            .public_key_use
                            .as_ref()
                            .ok_or_else(TokenserverError::internal_error)?,
                        PublicKeyUse::Signature
                    ) {
                        return Err(TokenserverError::internal_error());
                    }
                    let dict = [
                        ("kty", "RSA"),
                        ("alg", &alg),
                        ("kid", kid),
                        ("use", "sig"),
                        ("n", &n),
                        ("e", e),
                    ]
                    .into_py_dict(py);
                    Ok(dict)
                };
                // `jwks` is `None` when no keys were configured locally; the Python
                // client then fetches keys from FxA on demand.
                let jwks = match (
                    &settings.fxa_oauth_primary_jwk,
                    &settings.fxa_oauth_secondary_jwk,
                ) {
                    (Some(primary_jwk), Some(secondary_jwk)) => {
                        Some(vec![parse_jwk(primary_jwk)?, parse_jwk(secondary_jwk)?])
                    }
                    (Some(jwk), None) | (None, Some(jwk)) => Some(vec![parse_jwk(jwk)?]),
                    (None, None) => None,
                };
                dict.set_item("jwks", jwks).unwrap();
                dict
            };
            let object: Py<PyAny> = module
                .getattr("FxaOAuthClient")
                .map_err(pyerr_to_tokenserver_error)?
                .call((), Some(kwargs))
                .map_err(|e| {
                    // Print the Python traceback before converting the error.
                    e.print_and_set_sys_last_vars(py);
                    pyerr_to_tokenserver_error(e)
                })?
                .into();
            Ok(object)
        })?;
        Ok(Self {
            inner,
            timeout: settings.fxa_oauth_request_timeout,
            blocking_threadpool,
        })
    }
}
#[async_trait]
impl VerifyToken for Verifier {
    type Output = VerifyOutput;

    /// Verifies an OAuth token. Returns `VerifyOutput` for valid tokens and a `TokenserverError`
    /// for invalid tokens. The `_metrics` argument is unused here; only the native
    /// verifier records verification metrics.
    async fn verify(
        &self,
        token: String,
        _metrics: &Metrics,
    ) -> Result<VerifyOutput, TokenserverError> {
        // We don't want to move `self` into the body of the closure here because we'd need to
        // clone it. Cloning it is only necessary if we need to verify the token remotely via FxA,
        // since that would require passing `self` to a separate thread. Passing &Self to a closure
        // gives us the flexibility to clone only when necessary.
        let verify_inner = |verifier: &Self| {
            let maybe_verify_output_string = Python::with_gil(|py| {
                let client = verifier.inner.as_ref(py);
                // `client.verify_token(token)`
                let result: &PyAny = client
                    .getattr("verify_token")?
                    .call((token,), None)
                    .map_err(|e| {
                        // Print the Python traceback before propagating the error.
                        e.print_and_set_sys_last_vars(py);
                        e
                    })?;
                // The Python side returns `None` for a token it could not verify, and a
                // JSON string describing the verification result otherwise.
                if result.is_none() {
                    Ok(None)
                } else {
                    let verify_output_python_string = result.downcast::<PyString>()?;
                    verify_output_python_string.extract::<String>().map(Some)
                }
            })
            .map_err(|e| TokenserverError {
                context: format!("pyo3 error in OAuth verifier: {}", e),
                ..TokenserverError::invalid_credentials("Unauthorized".to_owned())
            })?;
            // Deserialize the JSON payload into `VerifyOutput`; `None` means the token
            // was rejected.
            match maybe_verify_output_string {
                Some(verify_output_string) => {
                    serde_json::from_str::<VerifyOutput>(&verify_output_string).map_err(|e| {
                        TokenserverError {
                            context: format!("Invalid OAuth verify output: {}", e),
                            ..TokenserverError::invalid_credentials("Unauthorized".to_owned())
                        }
                    })
                }
                None => Err(TokenserverError {
                    context: "Invalid OAuth token".to_owned(),
                    ..TokenserverError::invalid_credentials("Unauthorized".to_owned())
                }),
            }
        };
        let verifier = self.clone();
        // If the JWK is not cached or if the token is not a JWT/wasn't signed by a known key
        // type, PyFxA will make a request to FxA to retrieve it, blocking this thread. To
        // improve performance, we make the request on a thread in a threadpool specifically
        // used for blocking operations. The JWK should _always_ be cached in production to
        // maximize performance.
        let fut = self
            .blocking_threadpool
            .spawn(move || verify_inner(&verifier));
        // The PyFxA OAuth client does not offer a way to set a request timeout, so we set one here
        // by timing out the future if the verification process blocks this thread for longer
        // than the specified number of seconds.
        time::timeout(Duration::from_secs(self.timeout), fut)
            .await
            .map_err(|_| TokenserverError {
                context: "OAuth verification timeout".to_owned(),
                ..TokenserverError::resource_unavailable()
            })?
    }
}
fn pyerr_to_tokenserver_error(e: PyErr) -> TokenserverError {
TokenserverError {
context: e.to_string(),
..TokenserverError::internal_error()
}
}

View File

@ -0,0 +1,11 @@
// Tokenlib backend selection: the pure-Rust implementation is the default; the `py`
// feature swaps in the PyO3 wrapper around the Python `tokenlib` package instead.
#[cfg(not(feature = "py"))]
mod native;
#[cfg(feature = "py")]
mod py;

// `Tokenlib` resolves to whichever backend the `py` feature selects; both expose the
// same `get_token_and_derived_secret` interface.
#[cfg(feature = "py")]
pub type Tokenlib = py::PyTokenlib;
#[cfg(not(feature = "py"))]
pub type Tokenlib = native::Tokenlib;

View File

@ -0,0 +1,113 @@
use crate::{
crypto::{Crypto, CryptoImpl},
MakeTokenPlaintext,
};
use base64::Engine;
use serde::{Deserialize, Serialize};
use tokenserver_common::TokenserverError;
// These two constants were pulled directly from
// https://github.com/mozilla-services/tokenlib/blob/91ec9e2c922e55306eddba1394590a88f3b10602/tokenlib/__init__.py#L43-L45
// We could change them, but we'd want to make sure that we also change them in syncstorage;
// however, that would cause temporary auth issues for anyone holding a token minted with the
// previous values.
const HKDF_SIGNING_INFO: &[u8] = b"services.mozilla.com/tokenlib/v1/signing";
const HKDF_INFO_DERIVE: &[u8] = b"services.mozilla.com/tokenlib/v1/derive/";
/// Pure-Rust port of the parts of the Python `tokenlib` package that Tokenserver uses.
pub struct Tokenlib {}

/// The signed token payload: the caller-supplied plaintext plus a random hex salt.
#[derive(Debug, Serialize, Deserialize)]
struct Token<'a> {
    #[serde(flatten)]
    plaintext: MakeTokenPlaintext,
    // Hex-encoded random salt; also used when deriving the per-token secret.
    salt: &'a str,
}
impl Tokenlib {
    /// Builds a signed Tokenserver token from `plaintext` and derives the matching
    /// per-token secret, returning `(token, derived_secret)` — both base64 (URL-safe)
    /// encoded.
    pub fn get_token_and_derived_secret(
        plaintext: MakeTokenPlaintext,
        shared_secret: &str,
    ) -> Result<(String, String), TokenserverError> {
        // First we make the token itself; the code below was ported from:
        // https://github.com/mozilla-services/tokenlib/blob/91ec9e2c922e55306eddba1394590a88f3b10602/tokenlib/__init__.py#L96-L97
        let crypto_lib = CryptoImpl {};
        // 3 random bytes -> 6 hex characters of salt.
        let mut salt_bytes = [0u8; 3];
        crypto_lib.rand_bytes(&mut salt_bytes)?;
        let salt = hex::encode(salt_bytes);
        let token_str = serde_json::to_string(&Token {
            plaintext,
            salt: &salt,
        })
        .map_err(|_| TokenserverError::internal_error())?;
        // Sign the JSON payload with an HKDF-derived HMAC key, then append the
        // signature to the payload bytes.
        let hmac_key = crypto_lib.hkdf(shared_secret, None, HKDF_SIGNING_INFO)?;
        let signature = crypto_lib.hmac_sign(&hmac_key, token_str.as_bytes())?;
        let mut token_bytes = Vec::with_capacity(token_str.len() + signature.len());
        token_bytes.extend_from_slice(token_str.as_bytes());
        token_bytes.extend_from_slice(&signature);
        let token = base64::engine::general_purpose::URL_SAFE.encode(token_bytes);
        // Now that we finalized the token, let's generate our per-token secret.
        // The code below was ported from:
        // https://github.com/mozilla-services/tokenlib/blob/91ec9e2c922e55306eddba1394590a88f3b10602/tokenlib/__init__.py#L158-L159
        let mut info = Vec::with_capacity(HKDF_INFO_DERIVE.len() + token.as_bytes().len());
        info.extend_from_slice(HKDF_INFO_DERIVE);
        info.extend_from_slice(token.as_bytes());
        let per_token_secret = crypto_lib.hkdf(shared_secret, Some(salt.as_bytes()), &info)?;
        let per_token_secret = base64::engine::general_purpose::URL_SAFE.encode(per_token_secret);
        Ok((token, per_token_secret))
    }
}
#[cfg(test)]
mod tests {
    use crate::{crypto::SHA256_OUTPUT_LEN, TokenserverOrigin};

    use super::*;

    #[test]
    fn test_generate_valid_token_and_per_token_secret() -> Result<(), TokenserverError> {
        // First we verify that the token we generated has a valid
        // and correct HMAC signature if signed using the same key
        let plaintext = MakeTokenPlaintext {
            node: "https://www.example.com".to_string(),
            fxa_kid: "kid".to_string(),
            fxa_uid: "user uid".to_string(),
            hashed_fxa_uid: "hased uid".to_string(),
            hashed_device_id: "hashed device id".to_string(),
            expires: 1031,
            uid: 13,
            tokenserver_origin: TokenserverOrigin::Rust,
        };
        let secret = "foobar";
        let crypto_impl = CryptoImpl {};
        let hmac_key = crypto_impl.hkdf(secret, None, HKDF_SIGNING_INFO).unwrap();
        let (b64_token, per_token_secret) =
            Tokenlib::get_token_and_derived_secret(plaintext.clone(), secret).unwrap();
        // The signature is the trailing SHA-256-sized chunk of the decoded token;
        // everything before it is the JSON payload.
        let token = base64::engine::general_purpose::URL_SAFE
            .decode(&b64_token)
            .unwrap();
        let token_size = token.len();
        let signature = &token[token_size - SHA256_OUTPUT_LEN..];
        let payload = &token[..token_size - SHA256_OUTPUT_LEN];
        crypto_impl
            .hmac_verify(&hmac_key, payload, signature)
            .unwrap();
        // Then we verify that the payload value we signed, is a valid
        // Token represented by our Token struct, and has exactly the same
        // plain_text values
        let token_data = serde_json::from_slice::<Token<'_>>(payload).unwrap();
        assert_eq!(token_data.plaintext, plaintext);
        // Finally, we verify that the same per_token_secret can be derived given the payload
        // and the shared secret
        let mut info = Vec::with_capacity(HKDF_INFO_DERIVE.len() + b64_token.as_bytes().len());
        info.extend_from_slice(HKDF_INFO_DERIVE);
        info.extend_from_slice(b64_token.as_bytes());
        let expected_per_token_secret =
            crypto_impl.hkdf(secret, Some(token_data.salt.as_bytes()), &info)?;
        let expected_per_token_secret =
            base64::engine::general_purpose::URL_SAFE.encode(expected_per_token_secret);
        assert_eq!(expected_per_token_secret, per_token_secret);
        Ok(())
    }
}

View File

@ -0,0 +1,71 @@
use crate::{MakeTokenPlaintext, TokenserverError};
use pyo3::{
prelude::{IntoPy, PyErr, PyModule, PyObject, Python},
types::IntoPyDict,
};
/// Thin wrapper around the Python `tokenlib` package, used when the `py` feature is on.
pub struct PyTokenlib {}

impl IntoPy<PyObject> for MakeTokenPlaintext {
    /// Converts the token plaintext into a Python dict suitable for `tokenlib.make_token`.
    fn into_py(self, py: Python<'_>) -> PyObject {
        let dict = [
            ("node", self.node),
            ("fxa_kid", self.fxa_kid),
            ("fxa_uid", self.fxa_uid),
            ("hashed_device_id", self.hashed_device_id),
            ("hashed_fxa_uid", self.hashed_fxa_uid),
            ("tokenserver_origin", self.tokenserver_origin.to_string()),
        ]
        .into_py_dict(py);
        // These need to be set separately since they aren't strings, and
        // Rust doesn't support heterogeneous arrays
        dict.set_item("expires", self.expires).unwrap();
        dict.set_item("uid", self.uid).unwrap();
        dict.into()
    }
}
impl PyTokenlib {
    /// Calls the Python `tokenlib` package to build a signed token from `plaintext`
    /// and derive the matching per-token secret. Returns `(token, derived_secret)`.
    /// Any Python exception is surfaced (with its traceback printed) as an internal
    /// Tokenserver error.
    pub fn get_token_and_derived_secret(
        plaintext: MakeTokenPlaintext,
        shared_secret: &str,
    ) -> Result<(String, String), TokenserverError> {
        Python::with_gil(|py| {
            // `import tokenlib`
            let module = PyModule::import(py, "tokenlib").map_err(|e| {
                e.print_and_set_sys_last_vars(py);
                e
            })?;
            // `kwargs = { 'secret': shared_secret }`
            let kwargs = [("secret", shared_secret)].into_py_dict(py);
            // `token = tokenlib.make_token(plaintext, **kwargs)`
            let token = module
                .getattr("make_token")?
                .call((plaintext,), Some(kwargs))
                .map_err(|e| {
                    e.print_and_set_sys_last_vars(py);
                    e
                })
                .and_then(|x| x.extract())?;
            // `derived_secret = tokenlib.get_derived_secret(token, **kwargs)`
            let derived_secret = module
                .getattr("get_derived_secret")?
                .call((&token,), Some(kwargs))
                .map_err(|e| {
                    e.print_and_set_sys_last_vars(py);
                    e
                })
                .and_then(|x| x.extract())?;
            // `return (token, derived_secret)`
            Ok((token, derived_secret))
        })
        .map_err(pyerr_to_tokenserver_error)
    }
}
fn pyerr_to_tokenserver_error(e: PyErr) -> TokenserverError {
TokenserverError {
context: e.to_string(),
..TokenserverError::internal_error()
}
}

View File

@ -10,6 +10,7 @@ actix-web.workspace = true
backtrace.workspace = true
serde.workspace = true
serde_json.workspace = true
jsonwebtoken.workspace = true
thiserror.workspace = true
syncserver-common = { path = "../syncserver-common" }
thiserror = "1.0.26"

View File

@ -13,6 +13,7 @@ serde.workspace = true
serde_derive.workspace = true
serde_json.workspace = true
slog-scope.workspace = true
thiserror.workspace = true
async-trait = "0.1.40"
diesel = { version = "1.4", features = ["mysql", "r2d2"] }
@ -20,7 +21,6 @@ diesel_logger = "0.1.1"
diesel_migrations = { version = "1.4.0", features = ["mysql"] }
syncserver-common = { path = "../syncserver-common" }
syncserver-db-common = { path = "../syncserver-db-common" }
thiserror = "1.0.26"
tokenserver-common = { path = "../tokenserver-common" }
tokenserver-settings = { path = "../tokenserver-settings" }
tokio = { workspace = true, features = ["macros", "sync"] }

View File

@ -7,5 +7,6 @@ edition.workspace=true
[dependencies]
serde.workspace=true
jsonwebtoken.workspace=true
tokenserver-common = { path = "../tokenserver-common" }

View File

@ -1,3 +1,4 @@
use jsonwebtoken::jwk::Jwk;
use serde::Deserialize;
use tokenserver_common::NodeType;
@ -69,18 +70,6 @@ pub struct Settings {
pub token_duration: u64,
}
#[derive(Clone, Debug, Deserialize)]
pub struct Jwk {
pub kty: String,
pub alg: String,
pub kid: String,
pub fxa_created_at: u64,
#[serde(rename = "use")]
pub use_of_key: String,
pub n: String,
pub e: String,
}
impl Default for Settings {
fn default() -> Settings {
Settings {

View File

@ -215,14 +215,14 @@ class TestE2e(TestCase, unittest.TestCase):
raw = urlsafe_b64decode(res.json['id'])
payload = raw[:-32]
signature = raw[-32:]
payload_dict = json.loads(payload.decode('utf-8'))
payload_str = payload.decode('utf-8')
payload_dict = json.loads(payload_str)
# The `id` payload should include a field indicating the origin of the
# token
self.assertEqual(payload_dict['tokenserver_origin'], 'rust')
signing_secret = self.TOKEN_SIGNING_SECRET
expected_token = tokenlib.make_token(payload_dict,
secret=signing_secret)
expected_signature = urlsafe_b64decode(expected_token)[-32:]
tm = tokenlib.TokenManager(secret=signing_secret)
expected_signature = tm._get_signature(payload_str.encode('utf8'))
# Using the #compare_digest method here is not strictly necessary, as
# this is not a security-sensitive situation, but it's good practice
self.assertTrue(hmac.compare_digest(expected_signature, signature))
@ -271,12 +271,11 @@ class TestE2e(TestCase, unittest.TestCase):
raw = urlsafe_b64decode(res.json['id'])
payload = raw[:-32]
signature = raw[-32:]
payload_dict = json.loads(payload.decode('utf-8'))
payload_str = payload.decode('utf-8')
signing_secret = self.TOKEN_SIGNING_SECRET
expected_token = tokenlib.make_token(payload_dict,
secret=signing_secret)
expected_signature = urlsafe_b64decode(expected_token)[-32:]
tm = tokenlib.TokenManager(secret=signing_secret)
expected_signature = tm._get_signature(payload_str.encode('utf8'))
# Using the #compare_digest method here is not strictly necessary, as
# this is not a security-sensitive situation, but it's good practice
self.assertTrue(hmac.compare_digest(expected_signature, signature))