mirror of
https://github.com/mozilla-services/syncstorage-rs.git
synced 2026-05-05 04:06:16 +02:00
feat: Puts pyo3 behind feature flag and derives tokens directly in Rust (#1513)
* Removes pyo3 and derives tokens directly in Rust * Adds tests for JWT verifying * Adds tests for token generation * Adds metrics for oauth verify error cases * Updates jsonwebtoken to not include default features (including pem loading) * Adds context and logs errors during oauth verify * Uses ring for cryptographic rng * Adds back python impl under feature flag * Uses one cached http client for reqwest
This commit is contained in:
parent
d544a0e378
commit
1b11684648
@ -48,13 +48,13 @@ commands:
|
||||
- run:
|
||||
name: Rust Clippy MySQL
|
||||
command: |
|
||||
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/mysql -- -D warnings
|
||||
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/mysql --features=py_verifier -- -D warnings
|
||||
rust-clippy-spanner:
|
||||
steps:
|
||||
- run:
|
||||
name: Rust Clippy Spanner
|
||||
command: |
|
||||
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/spanner -- -D warnings
|
||||
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/spanner --features=py_verifier -- -D warnings
|
||||
cargo-build:
|
||||
steps:
|
||||
- run:
|
||||
|
||||
26
Cargo.lock
generated
26
Cargo.lock
generated
@ -1068,8 +1068,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"wasi",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1439,6 +1441,19 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jsonwebtoken"
|
||||
version = "9.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c7ea04a7c5c055c175f189b6dc6ba036fd62306b58c66c9f6389036c503a3f4"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"js-sys",
|
||||
"ring",
|
||||
"serde 1.0.195",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "language-tags"
|
||||
version = "0.3.2"
|
||||
@ -3032,14 +3047,23 @@ name = "tokenserver-auth"
|
||||
version = "0.14.4"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"base64",
|
||||
"dyn-clone",
|
||||
"futures 0.3.30",
|
||||
"hex",
|
||||
"hkdf",
|
||||
"hmac",
|
||||
"jsonwebtoken",
|
||||
"mockito",
|
||||
"pyo3",
|
||||
"reqwest",
|
||||
"ring",
|
||||
"serde 1.0.195",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"slog-scope",
|
||||
"syncserver-common",
|
||||
"thiserror",
|
||||
"tokenserver-common",
|
||||
"tokenserver-settings",
|
||||
"tokio",
|
||||
@ -3051,6 +3075,7 @@ version = "0.14.4"
|
||||
dependencies = [
|
||||
"actix-web",
|
||||
"backtrace",
|
||||
"jsonwebtoken",
|
||||
"serde 1.0.195",
|
||||
"serde_json",
|
||||
"syncserver-common",
|
||||
@ -3086,6 +3111,7 @@ dependencies = [
|
||||
name = "tokenserver-settings"
|
||||
version = "0.14.4"
|
||||
dependencies = [
|
||||
"jsonwebtoken",
|
||||
"serde 1.0.195",
|
||||
"tokenserver-common",
|
||||
]
|
||||
|
||||
@ -38,7 +38,10 @@ docopt = "1.1"
|
||||
env_logger = "0.10"
|
||||
futures = { version = "0.3", features = ["compat"] }
|
||||
hex = "0.4"
|
||||
hkdf = "0.12"
|
||||
hmac = "0.12"
|
||||
http = "0.2"
|
||||
jsonwebtoken = { version = "9.2", default-features = false }
|
||||
lazy_static = "1.4"
|
||||
protobuf = "=2.25.2" # pin to 2.25.2 to prevent side updating
|
||||
rand = "0.8"
|
||||
@ -62,6 +65,7 @@ slog-scope = "4.3"
|
||||
slog-stdlog = "4.1"
|
||||
slog-term = "2.6"
|
||||
tokio = "1"
|
||||
thiserror = "1.0.26"
|
||||
|
||||
[profile.release]
|
||||
# Enables line numbers in Sentry reporting
|
||||
|
||||
@ -19,7 +19,7 @@ RUN \
|
||||
apt-get -q install -y --no-install-recommends libmysqlclient-dev cmake
|
||||
|
||||
COPY --from=planner /app/recipe.json recipe.json
|
||||
RUN cargo chef cook --release --no-default-features --features=syncstorage-db/$DATABASE_BACKEND --recipe-path recipe.json
|
||||
RUN cargo chef cook --release --no-default-features --features=syncstorage-db/$DATABASE_BACKEND --features=py_verifier --recipe-path recipe.json
|
||||
|
||||
FROM chef as builder
|
||||
ARG DATABASE_BACKEND=spanner
|
||||
@ -46,7 +46,7 @@ ENV PATH=$PATH:/root/.cargo/bin
|
||||
RUN \
|
||||
cargo --version && \
|
||||
rustc --version && \
|
||||
cargo install --path ./syncserver --no-default-features --features=syncstorage-db/$DATABASE_BACKEND --locked --root /app && \
|
||||
cargo install --path ./syncserver --no-default-features --features=syncstorage-db/$DATABASE_BACKEND --features=py_verifier --locked --root /app && \
|
||||
if [ "$DATABASE_BACKEND" = "spanner" ] ; then cargo install --path ./syncstorage-spanner --locked --root /app --bin purge_ttl ; fi
|
||||
|
||||
FROM docker.io/library/debian:bullseye-slim
|
||||
@ -56,11 +56,12 @@ COPY --from=builder /app/requirements.txt /app
|
||||
# have to set this env var to prevent the cryptography package from building
|
||||
# with Rust. See this link for more information:
|
||||
# https://pythonshowcase.com/question/problem-installing-cryptography-on-raspberry-pi
|
||||
ENV CRYPTOGRAPHY_DONT_BUILD_RUST=1
|
||||
|
||||
RUN \
|
||||
apt-get -q update && apt-get -qy install wget
|
||||
|
||||
ENV CRYPTOGRAPHY_DONT_BUILD_RUST=1
|
||||
|
||||
RUN \
|
||||
groupadd --gid 10001 app && \
|
||||
useradd --uid 10001 --gid 10001 --home /app --create-home app && \
|
||||
|
||||
11
Makefile
11
Makefile
@ -15,11 +15,11 @@ PYTHON_SITE_PACKGES = $(shell $(SRC_ROOT)/venv/bin/python -c "from distutils.sys
|
||||
|
||||
clippy_mysql:
|
||||
# Matches what's run in circleci
|
||||
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/mysql -- -D warnings
|
||||
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/mysql --features=py_verifier -- -D warnings
|
||||
|
||||
clippy_spanner:
|
||||
# Matches what's run in circleci
|
||||
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/spanner -- -D warnings
|
||||
cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/spanner --features=py_verifier -- -D warnings
|
||||
|
||||
clean:
|
||||
cargo clean
|
||||
@ -47,14 +47,15 @@ python:
|
||||
python3 -m venv venv
|
||||
venv/bin/python -m pip install -r requirements.txt
|
||||
|
||||
|
||||
run_mysql: python
|
||||
PATH="./venv/bin:$(PATH)" \
|
||||
# See https://github.com/PyO3/pyo3/issues/1741 for discussion re: why we need to set the
|
||||
# below env var
|
||||
PYTHONPATH=$(PYTHON_SITE_PACKGES) \
|
||||
RUST_LOG=debug \
|
||||
RUST_LOG=debug \
|
||||
RUST_BACKTRACE=full \
|
||||
cargo run --no-default-features --features=syncstorage-db/mysql -- --config config/local.toml
|
||||
cargo run --no-default-features --features=syncstorage-db/mysql --features=py_verifier -- --config config/local.toml
|
||||
|
||||
run_spanner: python
|
||||
GOOGLE_APPLICATION_CREDENTIALS=$(PATH_TO_SYNC_SPANNER_KEYS) \
|
||||
@ -65,7 +66,7 @@ run_spanner: python
|
||||
PATH="./venv/bin:$(PATH)" \
|
||||
RUST_LOG=debug \
|
||||
RUST_BACKTRACE=full \
|
||||
cargo run --no-default-features --features=syncstorage-db/spanner -- --config config/local.toml
|
||||
cargo run --no-default-features --features=syncstorage-db/spanner --features=py_verifier -- --config config/local.toml
|
||||
|
||||
test:
|
||||
SYNC_SYNCSTORAGE__DATABASE_URL=mysql://sample_user:sample_password@localhost/syncstorage_rs \
|
||||
|
||||
@ -14,5 +14,5 @@ serde_json.workspace = true
|
||||
slog.workspace = true
|
||||
slog-scope.workspace = true
|
||||
actix-web.workspace = true
|
||||
hkdf.workspace = true
|
||||
|
||||
hkdf = "0.12"
|
||||
|
||||
@ -9,9 +9,9 @@ edition.workspace=true
|
||||
backtrace.workspace=true
|
||||
futures.workspace=true
|
||||
http.workspace=true
|
||||
thiserror.workspace=true
|
||||
|
||||
deadpool = { git = "https://github.com/mozilla-services/deadpool", tag = "deadpool-v0.7.0" }
|
||||
diesel = { version = "1.4", features = ["mysql", "r2d2"] }
|
||||
diesel_migrations = { version = "1.4.0", features = ["mysql"] }
|
||||
syncserver-common = { path = "../syncserver-common" }
|
||||
thiserror = "1.0.26"
|
||||
|
||||
@ -32,6 +32,8 @@ slog-mozlog-json.workspace = true
|
||||
slog-scope.workspace = true
|
||||
slog-stdlog.workspace = true
|
||||
slog-term.workspace = true
|
||||
hmac.workspace = true
|
||||
thiserror.workspace = true
|
||||
|
||||
actix-http = "3"
|
||||
actix-rt = "2"
|
||||
@ -40,7 +42,6 @@ async-trait = "0.1.40"
|
||||
dyn-clone = "1.0.4"
|
||||
hostname = "0.3.1"
|
||||
hawk = "5.0"
|
||||
hmac = "0.12"
|
||||
mime = "0.3"
|
||||
reqwest = { workspace = true, features = [
|
||||
"json",
|
||||
@ -53,8 +54,7 @@ syncserver-settings = { path = "../syncserver-settings" }
|
||||
syncstorage-db = { path = "../syncstorage-db" }
|
||||
syncstorage-settings = { path = "../syncstorage-settings" }
|
||||
time = "^0.3"
|
||||
thiserror = "1.0.26"
|
||||
tokenserver-auth = { path = "../tokenserver-auth" }
|
||||
tokenserver-auth = { path = "../tokenserver-auth", default-features = false}
|
||||
tokenserver-common = { path = "../tokenserver-common" }
|
||||
tokenserver-db = { path = "../tokenserver-db" }
|
||||
tokenserver-settings = { path = "../tokenserver-settings" }
|
||||
@ -65,7 +65,8 @@ validator_derive = "0.16"
|
||||
woothee = "0.13"
|
||||
|
||||
[features]
|
||||
default = ["mysql"]
|
||||
default = ["mysql", "py_verifier"]
|
||||
no_auth = []
|
||||
py_verifier = ["tokenserver-auth/py"]
|
||||
mysql = ["syncstorage-db/mysql"]
|
||||
spanner = ["syncstorage-db/spanner"]
|
||||
|
||||
@ -457,7 +457,8 @@ impl FromRequest for AuthData {
|
||||
let mut tags = HashMap::default();
|
||||
tags.insert("token_type".to_owned(), "BrowserID".to_owned());
|
||||
metrics.start_timer("token_verification", Some(tags));
|
||||
let verify_output = state.browserid_verifier.verify(assertion).await?;
|
||||
let verify_output =
|
||||
state.browserid_verifier.verify(assertion, &metrics).await?;
|
||||
|
||||
// For requests using BrowserID, the client state is embedded in the
|
||||
// X-Client-State header, and the generation and keys_changed_at are extracted
|
||||
@ -487,7 +488,7 @@ impl FromRequest for AuthData {
|
||||
let mut tags = HashMap::default();
|
||||
tags.insert("token_type".to_owned(), "OAuth".to_owned());
|
||||
metrics.start_timer("token_verification", Some(tags));
|
||||
let verify_output = state.oauth_verifier.verify(token).await?;
|
||||
let verify_output = state.oauth_verifier.verify(token, &metrics).await?;
|
||||
|
||||
// For requests using OAuth, the keys_changed_at and client state are embedded
|
||||
// in the X-KeyID header.
|
||||
|
||||
@ -9,6 +9,8 @@ use serde::{
|
||||
Serialize,
|
||||
};
|
||||
use syncserver_common::{BlockingThreadpool, Metrics};
|
||||
#[cfg(not(feature = "py_verifier"))]
|
||||
use tokenserver_auth::JWTVerifierImpl;
|
||||
use tokenserver_auth::{browserid, oauth, VerifyToken};
|
||||
use tokenserver_common::NodeType;
|
||||
use tokenserver_db::{params, DbPool, TokenserverPool};
|
||||
@ -40,6 +42,32 @@ impl ServerState {
|
||||
metrics: Arc<StatsdClient>,
|
||||
blocking_threadpool: Arc<BlockingThreadpool>,
|
||||
) -> Result<Self, ApiError> {
|
||||
#[cfg(not(feature = "py_verifier"))]
|
||||
let oauth_verifier = {
|
||||
let mut jwk_verifiers: Vec<JWTVerifierImpl> = Vec::new();
|
||||
if let Some(primary) = &settings.fxa_oauth_primary_jwk {
|
||||
jwk_verifiers.push(
|
||||
primary
|
||||
.clone()
|
||||
.try_into()
|
||||
.expect("Invalid primary key, should either be fixed or removed"),
|
||||
)
|
||||
}
|
||||
if let Some(secondary) = &settings.fxa_oauth_secondary_jwk {
|
||||
jwk_verifiers.push(
|
||||
secondary
|
||||
.clone()
|
||||
.try_into()
|
||||
.expect("Invalid secondary key, should either be fixed or removed"),
|
||||
);
|
||||
}
|
||||
Box::new(
|
||||
oauth::Verifier::new(settings, jwk_verifiers)
|
||||
.expect("failed to create Tokenserver OAuth verifier"),
|
||||
)
|
||||
};
|
||||
|
||||
#[cfg(feature = "py_verifier")]
|
||||
let oauth_verifier = Box::new(
|
||||
oauth::Verifier::new(settings, blocking_threadpool.clone())
|
||||
.expect("failed to create Tokenserver OAuth verifier"),
|
||||
|
||||
@ -13,10 +13,10 @@ lazy_static.workspace=true
|
||||
http.workspace=true
|
||||
serde.workspace=true
|
||||
serde_json.workspace=true
|
||||
thiserror.workspace=true
|
||||
|
||||
async-trait = "0.1.40"
|
||||
diesel = { version = "1.4", features = ["mysql", "r2d2"] }
|
||||
diesel_migrations = { version = "1.4.0", features = ["mysql"] }
|
||||
syncserver-common = { path = "../syncserver-common" }
|
||||
syncserver-db-common = { path = "../syncserver-db-common" }
|
||||
thiserror = "1.0.26"
|
||||
|
||||
@ -11,6 +11,7 @@ base64.workspace=true
|
||||
futures.workspace=true
|
||||
http.workspace=true
|
||||
slog-scope.workspace=true
|
||||
thiserror.workspace=true
|
||||
|
||||
async-trait = "0.1.40"
|
||||
diesel = { version = "1.4", features = ["mysql", "r2d2"] }
|
||||
@ -20,7 +21,6 @@ syncserver-common = { path = "../syncserver-common" }
|
||||
syncserver-db-common = { path = "../syncserver-db-common" }
|
||||
syncstorage-db-common = { path = "../syncstorage-db-common" }
|
||||
syncstorage-settings = { path = "../syncstorage-settings" }
|
||||
thiserror = "1.0.26"
|
||||
url = "2.1"
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@ -12,6 +12,7 @@ env_logger.workspace = true
|
||||
futures.workspace = true
|
||||
http.workspace = true
|
||||
slog-scope.workspace = true
|
||||
thiserror.workspace = true
|
||||
|
||||
async-trait = "0.1.40"
|
||||
google-cloud-rust-raw = { version = "0.16.1", features = ["spanner"] }
|
||||
@ -30,7 +31,6 @@ syncserver-common = { path = "../syncserver-common" }
|
||||
syncserver-db-common = { path = "../syncserver-db-common" }
|
||||
syncstorage-db-common = { path = "../syncstorage-db-common" }
|
||||
syncstorage-settings = { path = "../syncstorage-settings" }
|
||||
thiserror = "1.0.26"
|
||||
tokio = { workspace = true, features = [
|
||||
"macros",
|
||||
"sync",
|
||||
|
||||
@ -11,15 +11,30 @@ edition.workspace = true
|
||||
futures.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
hex.workspace = true
|
||||
hkdf.workspace = true
|
||||
hmac.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
base64.workspace = true
|
||||
sha2.workspace = true
|
||||
thiserror.workspace = true
|
||||
slog-scope.workspace = true
|
||||
|
||||
async-trait = "0.1.40"
|
||||
dyn-clone = "1.0.4"
|
||||
pyo3 = { version = "0.20", features = ["auto-initialize"] }
|
||||
reqwest = { workspace = true, features = ["json", "rustls-tls"] }
|
||||
ring = "0.17"
|
||||
syncserver-common = { path = "../syncserver-common" }
|
||||
tokenserver-common = { path = "../tokenserver-common" }
|
||||
tokenserver-settings = { path = "../tokenserver-settings" }
|
||||
tokio = { workspace = true }
|
||||
pyo3 = { version = "0.20", features = ["auto-initialize"], optional = true}
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
mockito = "0.30.0"
|
||||
tokio = { workspace = true, features = ["macros"]}
|
||||
|
||||
[features]
|
||||
default = ["py"]
|
||||
py = ["pyo3"]
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
use async_trait::async_trait;
|
||||
use reqwest::{Client as ReqwestClient, StatusCode};
|
||||
use serde::{de::Deserializer, Deserialize, Serialize};
|
||||
use syncserver_common::Metrics;
|
||||
use tokenserver_common::{ErrorLocation, TokenType, TokenserverError};
|
||||
use tokenserver_settings::Settings;
|
||||
|
||||
@ -52,7 +53,11 @@ impl VerifyToken for Verifier {
|
||||
|
||||
/// Verifies a BrowserID assertion. Returns `VerifyOutput` for valid assertions and a
|
||||
/// `TokenserverError` for invalid assertions.
|
||||
async fn verify(&self, assertion: String) -> Result<VerifyOutput, TokenserverError> {
|
||||
async fn verify(
|
||||
&self,
|
||||
assertion: String,
|
||||
_metrics: &Metrics,
|
||||
) -> Result<VerifyOutput, TokenserverError> {
|
||||
let response = self
|
||||
.request_client
|
||||
.post(&self.fxa_verifier_url)
|
||||
@ -313,7 +318,10 @@ mod tests {
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = verifier.verify("test".to_owned()).await.unwrap();
|
||||
let result = verifier
|
||||
.verify("test".to_owned(), &Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
mock.assert();
|
||||
|
||||
let expected_result = VerifyOutput {
|
||||
@ -345,7 +353,10 @@ mod tests {
|
||||
.with_header("content-type", "application/json")
|
||||
.create();
|
||||
|
||||
let error = verifier.verify(assertion.to_owned()).await.unwrap_err();
|
||||
let error = verifier
|
||||
.verify(assertion.to_owned(), &Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
mock.assert();
|
||||
|
||||
let expected_error = TokenserverError {
|
||||
@ -363,7 +374,10 @@ mod tests {
|
||||
.with_body("<h1>Server Error</h1>")
|
||||
.create();
|
||||
|
||||
let error = verifier.verify(assertion.to_owned()).await.unwrap_err();
|
||||
let error = verifier
|
||||
.verify(assertion.to_owned(), &Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
mock.assert();
|
||||
|
||||
let expected_error = TokenserverError {
|
||||
@ -381,7 +395,10 @@ mod tests {
|
||||
.with_body("{\"status\": \"error\"}")
|
||||
.create();
|
||||
|
||||
let error = verifier.verify(assertion.to_owned()).await.unwrap_err();
|
||||
let error = verifier
|
||||
.verify(assertion.to_owned(), &Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
mock.assert();
|
||||
|
||||
let expected_error = TokenserverError {
|
||||
@ -399,7 +416,10 @@ mod tests {
|
||||
.with_body("{\"status\": \"potato\"}")
|
||||
.create();
|
||||
|
||||
let error = verifier.verify(assertion.to_owned()).await.unwrap_err();
|
||||
let error = verifier
|
||||
.verify(assertion.to_owned(), &Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
mock.assert();
|
||||
|
||||
let expected_error = TokenserverError {
|
||||
@ -417,7 +437,10 @@ mod tests {
|
||||
.with_body("{\"status\": \"failure\", \"reason\": \"something broke\"}")
|
||||
.create();
|
||||
|
||||
let error = verifier.verify(assertion.to_owned()).await.unwrap_err();
|
||||
let error = verifier
|
||||
.verify(assertion.to_owned(), &Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
mock.assert();
|
||||
|
||||
let expected_error = TokenserverError {
|
||||
@ -434,7 +457,10 @@ mod tests {
|
||||
.with_body("{\"status\": \"failure\"}")
|
||||
.create();
|
||||
|
||||
let error = verifier.verify(assertion.to_owned()).await.unwrap_err();
|
||||
let error = verifier
|
||||
.verify(assertion.to_owned(), &Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
mock.assert();
|
||||
|
||||
let expected_error = TokenserverError {
|
||||
@ -481,7 +507,10 @@ mod tests {
|
||||
|
||||
{
|
||||
let mock = mock("login.persona.org");
|
||||
let error = verifier.verify(assertion.clone()).await.unwrap_err();
|
||||
let error = verifier
|
||||
.verify(assertion.clone(), &Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
mock.assert();
|
||||
assert_eq!(expected_error, error);
|
||||
@ -489,7 +518,10 @@ mod tests {
|
||||
|
||||
{
|
||||
let mock = mock(ISSUER);
|
||||
let result = verifier.verify(assertion.clone()).await.unwrap();
|
||||
let result = verifier
|
||||
.verify(assertion.clone(), &Default::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let expected_result = VerifyOutput {
|
||||
device_id: None,
|
||||
email: "test@example.com".to_owned(),
|
||||
@ -503,7 +535,10 @@ mod tests {
|
||||
|
||||
{
|
||||
let mock = mock("accounts.firefox.org");
|
||||
let error = verifier.verify(assertion.clone()).await.unwrap_err();
|
||||
let error = verifier
|
||||
.verify(assertion.clone(), &Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
mock.assert();
|
||||
assert_eq!(expected_error, error);
|
||||
@ -511,7 +546,10 @@ mod tests {
|
||||
|
||||
{
|
||||
let mock = mock("http://accounts.firefox.com");
|
||||
let error = verifier.verify(assertion.clone()).await.unwrap_err();
|
||||
let error = verifier
|
||||
.verify(assertion.clone(), &Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
mock.assert();
|
||||
assert_eq!(expected_error, error);
|
||||
@ -519,7 +557,10 @@ mod tests {
|
||||
|
||||
{
|
||||
let mock = mock("accounts.firefox.co");
|
||||
let error = verifier.verify(assertion.clone()).await.unwrap_err();
|
||||
let error = verifier
|
||||
.verify(assertion.clone(), &Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
mock.assert();
|
||||
assert_eq!(expected_error, error);
|
||||
@ -536,7 +577,10 @@ mod tests {
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(body.to_string())
|
||||
.create();
|
||||
let error = verifier.verify(assertion.clone()).await.unwrap_err();
|
||||
let error = verifier
|
||||
.verify(assertion.clone(), &Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
mock.assert();
|
||||
|
||||
let expected_error = TokenserverError {
|
||||
@ -558,7 +602,10 @@ mod tests {
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(body.to_string())
|
||||
.create();
|
||||
let error = verifier.verify(assertion.clone()).await.unwrap_err();
|
||||
let error = verifier
|
||||
.verify(assertion.clone(), &Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
mock.assert();
|
||||
|
||||
let expected_error = TokenserverError {
|
||||
@ -579,7 +626,10 @@ mod tests {
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(body.to_string())
|
||||
.create();
|
||||
let error = verifier.verify(assertion).await.unwrap_err();
|
||||
let error = verifier
|
||||
.verify(assertion, &Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
mock.assert();
|
||||
|
||||
let expected_error = TokenserverError {
|
||||
|
||||
172
tokenserver-auth/src/crypto.rs
Normal file
172
tokenserver-auth/src/crypto.rs
Normal file
@ -0,0 +1,172 @@
|
||||
use hkdf::Hkdf;
|
||||
use hmac::{Hmac, Mac};
|
||||
use jsonwebtoken::{errors::ErrorKind, jwk::Jwk, Algorithm, DecodingKey, Validation};
|
||||
use ring::rand::{SecureRandom, SystemRandom};
|
||||
use serde::de::DeserializeOwned;
|
||||
use sha2::Sha256;
|
||||
use tokenserver_common::TokenserverError;
|
||||
pub const SHA256_OUTPUT_LEN: usize = 32;
|
||||
/// A triat representing all the required cryptographic operations by the token server
|
||||
pub trait Crypto {
|
||||
type Error;
|
||||
/// HKDF key derivation
|
||||
///
|
||||
/// This expands `info` into a 32 byte value using `secret` and the optional `salt`.
|
||||
/// Salt is normally specified, except when this function is called in [syncserver-settings::Secrets::new] or when deriving
|
||||
/// a key to be used to sign the tokenserver tokens, so both syncserver and tokenserver can
|
||||
/// sign and validate the signatures
|
||||
fn hkdf(&self, secret: &str, salt: Option<&[u8]>, info: &[u8]) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// HMAC signiture
|
||||
///
|
||||
/// Signs the `payload` using HMAC given the `key`
|
||||
fn hmac_sign(&self, key: &[u8], payload: &[u8]) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Verify an HMAC signature on a payload given a shared key
|
||||
fn hmac_verify(&self, key: &[u8], payload: &[u8], signature: &[u8]) -> Result<(), Self::Error>;
|
||||
|
||||
/// Generates random bytes using a cryptographic random number generator
|
||||
/// and fills `output` with those bytes
|
||||
fn rand_bytes(&self, output: &mut [u8]) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
/// An implementation for the needed cryptographic using
|
||||
/// the hmac crate for hmac and hkdf crate for hkdf
|
||||
/// it uses ring for the random number generation
|
||||
pub struct CryptoImpl {}
|
||||
|
||||
impl Crypto for CryptoImpl {
|
||||
type Error = TokenserverError;
|
||||
fn hkdf(&self, secret: &str, salt: Option<&[u8]>, info: &[u8]) -> Result<Vec<u8>, Self::Error> {
|
||||
let hk = Hkdf::<Sha256>::new(salt, secret.as_bytes());
|
||||
let mut okm = [0u8; SHA256_OUTPUT_LEN];
|
||||
hk.expand(info, &mut okm)
|
||||
.map_err(|_| TokenserverError::internal_error())?;
|
||||
Ok(okm.to_vec())
|
||||
}
|
||||
|
||||
fn hmac_sign(&self, key: &[u8], payload: &[u8]) -> Result<Vec<u8>, Self::Error> {
|
||||
let mut mac: Hmac<Sha256> =
|
||||
Hmac::new_from_slice(key).map_err(|_| TokenserverError::internal_error())?;
|
||||
mac.update(payload);
|
||||
Ok(mac.finalize().into_bytes().to_vec())
|
||||
}
|
||||
|
||||
fn hmac_verify(&self, key: &[u8], payload: &[u8], signature: &[u8]) -> Result<(), Self::Error> {
|
||||
let mut mac: Hmac<Sha256> =
|
||||
Hmac::new_from_slice(key).map_err(|_| TokenserverError::internal_error())?;
|
||||
mac.update(payload);
|
||||
mac.verify_slice(signature)
|
||||
.map_err(|_| TokenserverError::internal_error())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn rand_bytes(&self, output: &mut [u8]) -> Result<(), Self::Error> {
|
||||
let rng = SystemRandom::new();
|
||||
rng.fill(output)
|
||||
.map_err(|_| TokenserverError::internal_error())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// OAuthVerifyError captures the errors possible while verifing an OAuth JWT access token
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OAuthVerifyError {
|
||||
#[error("The signature has expired")]
|
||||
ExpiredSignature,
|
||||
#[error("Untrusted token")]
|
||||
TrustError,
|
||||
#[error("Invalid Key")]
|
||||
InvalidKey,
|
||||
#[error("Error decoding JWT")]
|
||||
DecodingError,
|
||||
#[error("The key was well formatted, but the signature was invalid")]
|
||||
InvalidSignature,
|
||||
}
|
||||
|
||||
impl OAuthVerifyError {
|
||||
pub fn metric_label(&self) -> &'static str {
|
||||
match self {
|
||||
Self::ExpiredSignature => "oauth.error.expired_signature",
|
||||
Self::TrustError => "oauth.error.trust_error",
|
||||
Self::InvalidKey => "oauth.error.invalid_key",
|
||||
Self::InvalidSignature => "oauth.error.invalid_signature",
|
||||
Self::DecodingError => "oauth.error.decoding_error",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_reportable_err(&self) -> bool {
|
||||
matches!(self, Self::InvalidKey | Self::DecodingError)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<jsonwebtoken::errors::Error> for OAuthVerifyError {
|
||||
fn from(value: jsonwebtoken::errors::Error) -> Self {
|
||||
match value.kind() {
|
||||
ErrorKind::InvalidKeyFormat => OAuthVerifyError::InvalidKey,
|
||||
ErrorKind::InvalidSignature => OAuthVerifyError::InvalidSignature,
|
||||
ErrorKind::ExpiredSignature => OAuthVerifyError::ExpiredSignature,
|
||||
_ => OAuthVerifyError::DecodingError,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A trait representing a JSON Web Token verifier <https://datatracker.ietf.org/doc/html/rfc7519>
|
||||
pub trait JWTVerifier: TryFrom<Self::Key, Error = OAuthVerifyError> + Sync + Send + Clone {
|
||||
type Key: DeserializeOwned;
|
||||
|
||||
fn verify<T: DeserializeOwned>(&self, token: &str) -> Result<T, OAuthVerifyError>;
|
||||
}
|
||||
|
||||
/// An implementation of the JWT verifier using the jsonwebtoken crate
|
||||
#[derive(Clone)]
|
||||
pub struct JWTVerifierImpl {
|
||||
key: DecodingKey,
|
||||
validation: Validation,
|
||||
}
|
||||
|
||||
impl JWTVerifier for JWTVerifierImpl {
|
||||
type Key = Jwk;
|
||||
|
||||
fn verify<T: DeserializeOwned>(&self, token: &str) -> Result<T, OAuthVerifyError> {
|
||||
let token_data = jsonwebtoken::decode::<T>(token, &self.key, &self.validation)?;
|
||||
token_data
|
||||
.header
|
||||
.typ
|
||||
.ok_or(OAuthVerifyError::TrustError)
|
||||
.and_then(|typ| {
|
||||
// Ref https://tools.ietf.org/html/rfc7515#section-4.1.9 the `typ` header
|
||||
// is lowercase and has an implicit default `application/` prefix.
|
||||
let typ = if !typ.contains('/') {
|
||||
format!("application/{}", typ)
|
||||
} else {
|
||||
typ
|
||||
};
|
||||
if typ.to_lowercase() != "application/at+jwt" {
|
||||
return Err(OAuthVerifyError::TrustError);
|
||||
}
|
||||
Ok(typ)
|
||||
})?;
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Jwk> for JWTVerifierImpl {
|
||||
type Error = OAuthVerifyError;
|
||||
fn try_from(value: Jwk) -> Result<Self, Self::Error> {
|
||||
let decoding_key =
|
||||
DecodingKey::from_jwk(&value).map_err(|_| OAuthVerifyError::InvalidKey)?;
|
||||
let mut validation = Validation::new(Algorithm::RS256);
|
||||
// The FxA OAuth ecosystem currently doesn't make good use of aud, and
|
||||
// instead relies on scope for restricting which services can accept
|
||||
// which tokens. So there's no value in checking it here, and in fact if
|
||||
// we check it here, it fails because the right audience isn't being
|
||||
// requested.
|
||||
validation.validate_aud = false;
|
||||
|
||||
Ok(Self {
|
||||
key: decoding_key,
|
||||
validation,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,17 +1,21 @@
|
||||
pub mod browserid;
|
||||
|
||||
#[cfg(not(feature = "py"))]
|
||||
mod crypto;
|
||||
|
||||
#[cfg(not(feature = "py"))]
|
||||
pub use crypto::{JWTVerifier, JWTVerifierImpl};
|
||||
pub mod oauth;
|
||||
mod token;
|
||||
use syncserver_common::Metrics;
|
||||
pub use token::Tokenlib;
|
||||
|
||||
use std::fmt;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dyn_clone::{self, DynClone};
|
||||
use pyo3::{
|
||||
prelude::{IntoPy, PyErr, PyModule, PyObject, Python},
|
||||
types::IntoPyDict,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokenserver_common::TokenserverError;
|
||||
|
||||
/// Represents the origin of the token used by Sync clients to access their data.
|
||||
#[derive(Clone, Copy, Default, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
@ -33,7 +37,7 @@ impl fmt::Display for TokenserverOrigin {
|
||||
}
|
||||
|
||||
/// The plaintext needed to build a token.
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
|
||||
pub struct MakeTokenPlaintext {
|
||||
pub node: String,
|
||||
pub fxa_kid: String,
|
||||
@ -45,69 +49,6 @@ pub struct MakeTokenPlaintext {
|
||||
pub tokenserver_origin: TokenserverOrigin,
|
||||
}
|
||||
|
||||
impl IntoPy<PyObject> for MakeTokenPlaintext {
|
||||
fn into_py(self, py: Python<'_>) -> PyObject {
|
||||
let dict = [
|
||||
("node", self.node),
|
||||
("fxa_kid", self.fxa_kid),
|
||||
("fxa_uid", self.fxa_uid),
|
||||
("hashed_device_id", self.hashed_device_id),
|
||||
("hashed_fxa_uid", self.hashed_fxa_uid),
|
||||
("tokenserver_origin", self.tokenserver_origin.to_string()),
|
||||
]
|
||||
.into_py_dict(py);
|
||||
|
||||
// These need to be set separately since they aren't strings, and
|
||||
// Rust doesn't support heterogeneous arrays
|
||||
dict.set_item("expires", self.expires).unwrap();
|
||||
dict.set_item("uid", self.uid).unwrap();
|
||||
|
||||
dict.into()
|
||||
}
|
||||
}
|
||||
|
||||
/// An adapter to the tokenlib Python library.
|
||||
pub struct Tokenlib;
|
||||
|
||||
impl Tokenlib {
|
||||
/// Builds the token and derived secret to be returned by Tokenserver.
|
||||
pub fn get_token_and_derived_secret(
|
||||
plaintext: MakeTokenPlaintext,
|
||||
shared_secret: &str,
|
||||
) -> Result<(String, String), TokenserverError> {
|
||||
Python::with_gil(|py| {
|
||||
// `import tokenlib`
|
||||
let module = PyModule::import(py, "tokenlib").map_err(|e| {
|
||||
e.print_and_set_sys_last_vars(py);
|
||||
e
|
||||
})?;
|
||||
// `kwargs = { 'secret': shared_secret }`
|
||||
let kwargs = [("secret", shared_secret)].into_py_dict(py);
|
||||
// `token = tokenlib.make_token(plaintext, **kwargs)`
|
||||
let token = module
|
||||
.getattr("make_token")?
|
||||
.call((plaintext,), Some(kwargs))
|
||||
.map_err(|e| {
|
||||
e.print_and_set_sys_last_vars(py);
|
||||
e
|
||||
})
|
||||
.and_then(|x| x.extract())?;
|
||||
// `derived_secret = tokenlib.get_derived_secret(token, **kwargs)`
|
||||
let derived_secret = module
|
||||
.getattr("get_derived_secret")?
|
||||
.call((&token,), Some(kwargs))
|
||||
.map_err(|e| {
|
||||
e.print_and_set_sys_last_vars(py);
|
||||
e
|
||||
})
|
||||
.and_then(|x| x.extract())?;
|
||||
// `return (token, derived_secret)`
|
||||
Ok((token, derived_secret))
|
||||
})
|
||||
.map_err(pyerr_to_tokenserver_error)
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementers of this trait can be used to verify tokens for Tokenserver.
|
||||
#[async_trait]
|
||||
pub trait VerifyToken: DynClone + Sync + Send {
|
||||
@ -115,7 +56,11 @@ pub trait VerifyToken: DynClone + Sync + Send {
|
||||
|
||||
/// Verifies the given token. This function is async because token verification often involves
|
||||
/// making a request to a remote server.
|
||||
async fn verify(&self, token: String) -> Result<Self::Output, TokenserverError>;
|
||||
async fn verify(
|
||||
&self,
|
||||
token: String,
|
||||
metrics: &Metrics,
|
||||
) -> Result<Self::Output, TokenserverError>;
|
||||
}
|
||||
|
||||
dyn_clone::clone_trait_object!(<T> VerifyToken<Output=T>);
|
||||
@ -131,16 +76,9 @@ pub struct MockVerifier<T: Clone + Send + Sync> {
|
||||
impl<T: Clone + Send + Sync> VerifyToken for MockVerifier<T> {
|
||||
type Output = T;
|
||||
|
||||
async fn verify(&self, _token: String) -> Result<T, TokenserverError> {
|
||||
async fn verify(&self, _token: String, _metrics: &Metrics) -> Result<T, TokenserverError> {
|
||||
self.valid
|
||||
.then(|| self.verify_output.clone())
|
||||
.ok_or_else(|| TokenserverError::invalid_credentials("Unauthorized".to_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
fn pyerr_to_tokenserver_error(e: PyErr) -> TokenserverError {
|
||||
TokenserverError {
|
||||
context: e.to_string(),
|
||||
..TokenserverError::internal_error()
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,18 +1,15 @@
|
||||
use async_trait::async_trait;
|
||||
use pyo3::{
|
||||
prelude::{Py, PyAny, PyErr, PyModule, Python},
|
||||
types::{IntoPyDict, PyString},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
use syncserver_common::BlockingThreadpool;
|
||||
use tokenserver_common::TokenserverError;
|
||||
use tokenserver_settings::{Jwk, Settings};
|
||||
use tokio::time;
|
||||
|
||||
use super::VerifyToken;
|
||||
#[cfg(not(feature = "py"))]
|
||||
mod native;
|
||||
#[cfg(feature = "py")]
|
||||
mod py;
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
#[cfg(feature = "py")]
|
||||
pub type Verifier = py::Verifier;
|
||||
|
||||
#[cfg(not(feature = "py"))]
|
||||
pub type Verifier<J> = native::Verifier<J>;
|
||||
|
||||
/// The information extracted from a valid OAuth token.
|
||||
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
|
||||
@ -21,148 +18,3 @@ pub struct VerifyOutput {
|
||||
pub fxa_uid: String,
|
||||
pub generation: Option<i64>,
|
||||
}
|
||||
|
||||
/// The verifier used to verify OAuth tokens.
|
||||
#[derive(Clone)]
|
||||
pub struct Verifier {
|
||||
// Note that we do not need to use an Arc here, since Py is already a reference-counted
|
||||
// pointer
|
||||
inner: Py<PyAny>,
|
||||
timeout: u64,
|
||||
blocking_threadpool: Arc<BlockingThreadpool>,
|
||||
}
|
||||
|
||||
impl Verifier {
|
||||
const FILENAME: &'static str = "verify.py";
|
||||
|
||||
pub fn new(
|
||||
settings: &Settings,
|
||||
blocking_threadpool: Arc<BlockingThreadpool>,
|
||||
) -> Result<Self, TokenserverError> {
|
||||
let inner: Py<PyAny> = Python::with_gil::<_, Result<Py<PyAny>, PyErr>>(|py| {
|
||||
let code = include_str!("verify.py");
|
||||
let module = PyModule::from_code(py, code, Self::FILENAME, Self::FILENAME)?;
|
||||
let kwargs = {
|
||||
let dict = [("server_url", &settings.fxa_oauth_server_url)].into_py_dict(py);
|
||||
let parse_jwk = |jwk: &Jwk| {
|
||||
let dict = [
|
||||
("kty", &jwk.kty),
|
||||
("alg", &jwk.alg),
|
||||
("kid", &jwk.kid),
|
||||
("use", &jwk.use_of_key),
|
||||
("n", &jwk.n),
|
||||
("e", &jwk.e),
|
||||
]
|
||||
.into_py_dict(py);
|
||||
dict.set_item("fxa-createdAt", jwk.fxa_created_at).unwrap();
|
||||
|
||||
dict
|
||||
};
|
||||
|
||||
let jwks = match (
|
||||
&settings.fxa_oauth_primary_jwk,
|
||||
&settings.fxa_oauth_secondary_jwk,
|
||||
) {
|
||||
(Some(primary_jwk), Some(secondary_jwk)) => {
|
||||
Some(vec![parse_jwk(primary_jwk), parse_jwk(secondary_jwk)])
|
||||
}
|
||||
(Some(jwk), None) | (None, Some(jwk)) => Some(vec![parse_jwk(jwk)]),
|
||||
(None, None) => None,
|
||||
};
|
||||
dict.set_item("jwks", jwks).unwrap();
|
||||
dict
|
||||
};
|
||||
let object: Py<PyAny> = module
|
||||
.getattr("FxaOAuthClient")?
|
||||
.call((), Some(kwargs))
|
||||
.map_err(|e| {
|
||||
e.print_and_set_sys_last_vars(py);
|
||||
e
|
||||
})?
|
||||
.into();
|
||||
|
||||
Ok(object)
|
||||
})
|
||||
.map_err(super::pyerr_to_tokenserver_error)?;
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
timeout: settings.fxa_oauth_request_timeout,
|
||||
blocking_threadpool,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl VerifyToken for Verifier {
|
||||
type Output = VerifyOutput;
|
||||
|
||||
/// Verifies an OAuth token. Returns `VerifyOutput` for valid tokens and a `TokenserverError`
|
||||
/// for invalid tokens.
|
||||
async fn verify(&self, token: String) -> Result<VerifyOutput, TokenserverError> {
|
||||
// We don't want to move `self` into the body of the closure here because we'd need to
|
||||
// clone it. Cloning it is only necessary if we need to verify the token remotely via FxA,
|
||||
// since that would require passing `self` to a separate thread. Passing &Self to a closure
|
||||
// gives us the flexibility to clone only when necessary.
|
||||
let verify_inner = |verifier: &Self| {
|
||||
let maybe_verify_output_string = Python::with_gil(|py| {
|
||||
let client = verifier.inner.as_ref(py);
|
||||
// `client.verify_token(token)`
|
||||
let result: &PyAny = client
|
||||
.getattr("verify_token")?
|
||||
.call((token,), None)
|
||||
.map_err(|e| {
|
||||
e.print_and_set_sys_last_vars(py);
|
||||
e
|
||||
})?;
|
||||
|
||||
if result.is_none() {
|
||||
Ok(None)
|
||||
} else {
|
||||
let verify_output_python_string = result.downcast::<PyString>()?;
|
||||
verify_output_python_string.extract::<String>().map(Some)
|
||||
}
|
||||
})
|
||||
.map_err(|e| TokenserverError {
|
||||
context: format!("pyo3 error in OAuth verifier: {}", e),
|
||||
..TokenserverError::invalid_credentials("Unauthorized".to_owned())
|
||||
})?;
|
||||
|
||||
match maybe_verify_output_string {
|
||||
Some(verify_output_string) => {
|
||||
serde_json::from_str::<VerifyOutput>(&verify_output_string).map_err(|e| {
|
||||
TokenserverError {
|
||||
context: format!("Invalid OAuth verify output: {}", e),
|
||||
..TokenserverError::invalid_credentials("Unauthorized".to_owned())
|
||||
}
|
||||
})
|
||||
}
|
||||
None => Err(TokenserverError {
|
||||
context: "Invalid OAuth token".to_owned(),
|
||||
..TokenserverError::invalid_credentials("Unauthorized".to_owned())
|
||||
}),
|
||||
}
|
||||
};
|
||||
|
||||
let verifier = self.clone();
|
||||
|
||||
// If the JWK is not cached or if the token is not a JWT/wasn't signed by a known key
|
||||
// type, PyFxA will make a request to FxA to retrieve it, blocking this thread. To
|
||||
// improve performance, we make the request on a thread in a threadpool specifically
|
||||
// used for blocking operations. The JWK should _always_ be cached in production to
|
||||
// maximize performance.
|
||||
let fut = self
|
||||
.blocking_threadpool
|
||||
.spawn(move || verify_inner(&verifier));
|
||||
|
||||
// The PyFxA OAuth client does not offer a way to set a request timeout, so we set one here
|
||||
// by timing out the future if the verification process blocks this thread for longer
|
||||
// than the specified number of seconds.
|
||||
time::timeout(Duration::from_secs(self.timeout), fut)
|
||||
.await
|
||||
.map_err(|_| TokenserverError {
|
||||
context: "OAuth verification timeout".to_owned(),
|
||||
..TokenserverError::resource_unavailable()
|
||||
})?
|
||||
}
|
||||
}
|
||||
|
||||
557
tokenserver-auth/src/oauth/native.rs
Normal file
557
tokenserver-auth/src/oauth/native.rs
Normal file
@ -0,0 +1,557 @@
|
||||
use super::VerifyOutput;
|
||||
pub use crate::crypto::JWTVerifier;
|
||||
use crate::crypto::OAuthVerifyError;
|
||||
use crate::VerifyToken;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Url;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{borrow::Cow, time::Duration};
|
||||
use syncserver_common::Metrics;
|
||||
use tokenserver_common::TokenserverError;
|
||||
use tokenserver_settings::Settings;
|
||||
|
||||
const SYNC_SCOPE: &str = "https://identity.mozilla.com/apps/oldsync";
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct TokenClaims {
|
||||
#[serde(rename = "sub")]
|
||||
user: String,
|
||||
scope: String,
|
||||
#[serde(rename = "fxa-generation")]
|
||||
generation: Option<i64>,
|
||||
}
|
||||
|
||||
impl TokenClaims {
|
||||
fn validate(self) -> Result<VerifyOutput, TokenserverError> {
|
||||
if !self.scope.split(',').any(|scope| scope == SYNC_SCOPE) {
|
||||
return Err(TokenserverError::invalid_credentials(
|
||||
"Unauthorized".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(self.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TokenClaims> for VerifyOutput {
|
||||
fn from(value: TokenClaims) -> Self {
|
||||
Self {
|
||||
fxa_uid: value.user,
|
||||
generation: value.generation,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The verifier used to verify OAuth tokens.
|
||||
#[derive(Clone)]
|
||||
pub struct Verifier<J> {
|
||||
verify_url: Url,
|
||||
jwks_url: Url,
|
||||
jwk_verifiers: Vec<J>,
|
||||
http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl<J> Verifier<J>
|
||||
where
|
||||
J: JWTVerifier,
|
||||
{
|
||||
pub fn new(settings: &Settings, jwk_verifiers: Vec<J>) -> Result<Self, TokenserverError> {
|
||||
let base_url = Url::parse(&settings.fxa_oauth_server_url)
|
||||
.map_err(|_| TokenserverError::internal_error())?;
|
||||
let verify_url = base_url
|
||||
.join("v1/verify")
|
||||
.map_err(|_| TokenserverError::internal_error())?;
|
||||
let jwks_url = base_url
|
||||
.join("v1/jwks")
|
||||
.map_err(|_| TokenserverError::internal_error())?;
|
||||
let http_client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(settings.fxa_oauth_request_timeout))
|
||||
.use_rustls_tls()
|
||||
.build()
|
||||
.map_err(|_| TokenserverError::internal_error())?;
|
||||
|
||||
Ok(Self {
|
||||
verify_url,
|
||||
jwks_url,
|
||||
jwk_verifiers,
|
||||
http_client,
|
||||
})
|
||||
}
|
||||
|
||||
async fn remote_verify_token(&self, token: &str) -> Result<TokenClaims, TokenserverError> {
|
||||
#[derive(Serialize)]
|
||||
struct VerifyRequest<'a> {
|
||||
token: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct VerifyResponse {
|
||||
user: String,
|
||||
scope: Vec<String>,
|
||||
generation: Option<i64>,
|
||||
}
|
||||
|
||||
impl From<VerifyResponse> for TokenClaims {
|
||||
fn from(value: VerifyResponse) -> Self {
|
||||
Self {
|
||||
user: value.user,
|
||||
scope: value.scope.join(","),
|
||||
generation: value.generation,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(self
|
||||
.http_client
|
||||
.post(self.verify_url.clone())
|
||||
.json(&VerifyRequest { token })
|
||||
.send()
|
||||
.await
|
||||
.map_err(unauthorized_err_with_ctx)
|
||||
.and_then(|res| {
|
||||
if !res.status().is_success() {
|
||||
Err(unauthorized_err_with_ctx(format!(
|
||||
"Got verify status code: {}",
|
||||
res.status()
|
||||
)))
|
||||
} else {
|
||||
Ok(res)
|
||||
}
|
||||
})?
|
||||
.json::<VerifyResponse>()
|
||||
.await
|
||||
.map_err(unauthorized_err_with_ctx)?
|
||||
.into())
|
||||
}
|
||||
|
||||
async fn get_remote_jwks(&self) -> Result<Vec<J>, TokenserverError> {
|
||||
#[derive(Deserialize)]
|
||||
struct KeysResponse<K> {
|
||||
keys: Vec<K>,
|
||||
}
|
||||
self.http_client
|
||||
.get(self.jwks_url.clone())
|
||||
.send()
|
||||
.await
|
||||
.map_err(internal_err_with_ctx)?
|
||||
.json::<KeysResponse<J::Key>>()
|
||||
.await
|
||||
.map_err(internal_err_with_ctx)?
|
||||
.keys
|
||||
.into_iter()
|
||||
.map(|key| key.try_into().map_err(internal_err_with_ctx))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn verify_jwt_locally(
|
||||
&self,
|
||||
verifiers: &[Cow<'_, J>],
|
||||
token: &str,
|
||||
) -> Result<TokenClaims, OAuthVerifyError> {
|
||||
if verifiers.is_empty() {
|
||||
return Err(OAuthVerifyError::InvalidKey);
|
||||
}
|
||||
|
||||
verifiers
|
||||
.iter()
|
||||
.find_map(|verifier| {
|
||||
match verifier.verify::<TokenClaims>(token) {
|
||||
// If it's an invalid signature, it means our key was well formatted,
|
||||
// but the signature was incorrect. Lets try another key if we have any
|
||||
Err(OAuthVerifyError::InvalidSignature) => None,
|
||||
res => Some(res),
|
||||
}
|
||||
})
|
||||
// If there is nothing, it means all of our keys were well formatted, but none of them
|
||||
// were able to verify the signature, lets erturn a TrustError
|
||||
.ok_or(OAuthVerifyError::TrustError)?
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<J> VerifyToken for Verifier<J>
|
||||
where
|
||||
J: JWTVerifier,
|
||||
{
|
||||
type Output = VerifyOutput;
|
||||
|
||||
/// Verifies an OAuth token. Returns `VerifyOutput` for valid tokens and a `TokenserverError`
|
||||
/// for invalid tokens.
|
||||
///
|
||||
/// The verifier will first attempt to verify the token using FxA's public keys, which were
|
||||
/// provided as environment variables.
|
||||
///
|
||||
/// If FxA's public keys were not supplied, then the verifier will query FxA's /v1/jwks
|
||||
/// endpoint to get the latest public keys.
|
||||
///
|
||||
/// If verifying the tokens fails because the keys are
|
||||
/// invalid, or because the keys were valid but the tokens have changed their structure, then
|
||||
/// the verifier will fallback to hitting fxa's /v1/verify endpoint to verify instead. All
|
||||
/// other failures will be recorded as invalid credentials and will returns a generic "Unauthorized" message
|
||||
/// to the user
|
||||
async fn verify(
|
||||
&self,
|
||||
token: String,
|
||||
metrics: &Metrics,
|
||||
) -> Result<VerifyOutput, TokenserverError> {
|
||||
let mut verifiers = self
|
||||
.jwk_verifiers
|
||||
.iter()
|
||||
.map(Cow::Borrowed)
|
||||
.collect::<Vec<_>>();
|
||||
if self.jwk_verifiers.is_empty() {
|
||||
verifiers = self
|
||||
.get_remote_jwks()
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
slog_scope::warn!("Error requesting remote jwks: {}", e);
|
||||
vec![]
|
||||
})
|
||||
.into_iter()
|
||||
.map(Cow::Owned)
|
||||
.collect();
|
||||
}
|
||||
|
||||
let claims = match self.verify_jwt_locally(&verifiers, &token) {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
if e.is_reportable_err() {
|
||||
metrics.incr(e.metric_label())
|
||||
}
|
||||
match e {
|
||||
OAuthVerifyError::DecodingError | OAuthVerifyError::InvalidKey => {
|
||||
self.remote_verify_token(&token).await?
|
||||
}
|
||||
e => return Err(unauthorized_err_with_ctx(e)),
|
||||
}
|
||||
}
|
||||
};
|
||||
claims.validate()
|
||||
}
|
||||
}
|
||||
|
||||
fn unauthorized_err_with_ctx<E: std::fmt::Display>(err: E) -> TokenserverError {
|
||||
TokenserverError {
|
||||
context: err.to_string(),
|
||||
..TokenserverError::invalid_credentials("Unauthorized".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn internal_err_with_ctx<E: std::fmt::Display>(err: E) -> TokenserverError {
|
||||
TokenserverError {
|
||||
context: err.to_string(),
|
||||
..TokenserverError::internal_error()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::crypto::{JWTVerifierImpl, OAuthVerifyError};
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
#[derive(Deserialize)]
|
||||
struct MockJWK {}
|
||||
|
||||
macro_rules! mock_jwk_verifier {
|
||||
($im:expr) => {
|
||||
mock_jwk_verifier!(_token, $im);
|
||||
};
|
||||
($token:ident, $im:expr) => {
|
||||
#[derive(Clone, Debug)]
|
||||
struct MockJWTVerifier {}
|
||||
impl TryFrom<MockJWK> for MockJWTVerifier {
|
||||
type Error = OAuthVerifyError;
|
||||
fn try_from(_value: MockJWK) -> Result<Self, Self::Error> {
|
||||
Ok(Self {})
|
||||
}
|
||||
}
|
||||
|
||||
impl JWTVerifier for MockJWTVerifier {
|
||||
type Key = MockJWK;
|
||||
fn verify<T: ::serde::de::DeserializeOwned>(
|
||||
&self,
|
||||
$token: &str,
|
||||
) -> Result<T, OAuthVerifyError> {
|
||||
$im
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_no_keys_in_verifier_fallsback_to_fxa() -> Result<(), TokenserverError> {
|
||||
let mock_jwks = mockito::mock("GET", "/v1/jwks").with_status(500).create();
|
||||
|
||||
let body = json!({
|
||||
"user": "fxa_id",
|
||||
"scope": [SYNC_SCOPE],
|
||||
"generation": 123
|
||||
});
|
||||
let mock_verify = mockito::mock("POST", "/v1/verify")
|
||||
.with_header("content-type", "application/json")
|
||||
.with_status(200)
|
||||
.with_body(body.to_string())
|
||||
.create();
|
||||
|
||||
let settings = Settings {
|
||||
fxa_oauth_server_url: mockito::server_url(),
|
||||
..Default::default()
|
||||
};
|
||||
let verifer: Verifier<JWTVerifierImpl> = Verifier::new(&settings, vec![])?;
|
||||
let res = verifer
|
||||
.verify("a token fxa will validate".to_string(), &Default::default())
|
||||
.await?;
|
||||
mock_jwks.expect(1);
|
||||
mock_verify.expect(1);
|
||||
assert_eq!(res.generation.unwrap(), 123);
|
||||
assert_eq!(res.fxa_uid, "fxa_id");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_expired_signature_fails() -> Result<(), TokenserverError> {
|
||||
let mock = mockito::mock("POST", "/v1/verify").create();
|
||||
mock_jwk_verifier!(Err(OAuthVerifyError::InvalidSignature));
|
||||
|
||||
let jwk_verifiers = vec![MockJWTVerifier {}];
|
||||
let settings = Settings {
|
||||
fxa_oauth_server_url: mockito::server_url(),
|
||||
..Settings::default()
|
||||
};
|
||||
|
||||
let verifier: Verifier<MockJWTVerifier> = Verifier::new(&settings, jwk_verifiers)?;
|
||||
|
||||
let err = verifier
|
||||
.verify("An expired token".to_string(), &Default::default())
|
||||
.await
|
||||
.unwrap_err();
|
||||
// We also make sure we didn't try to hit the server
|
||||
mock.expect(0);
|
||||
assert_eq!(err.status, "invalid-credentials");
|
||||
assert_eq!(err.http_status, 401);
|
||||
assert_eq!(err.description, "Unauthorized");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_verifier_attempts_all_keys_if_invalid_signature() -> Result<(), TokenserverError>
|
||||
{
|
||||
let mock = mockito::mock("POST", "/v1/verify").create();
|
||||
#[derive(Debug, Clone)]
|
||||
struct MockJWTVerifier {
|
||||
id: u8,
|
||||
}
|
||||
|
||||
impl TryFrom<MockJWK> for MockJWTVerifier {
|
||||
type Error = OAuthVerifyError;
|
||||
fn try_from(_value: MockJWK) -> Result<Self, Self::Error> {
|
||||
Ok(Self { id: 0 })
|
||||
}
|
||||
}
|
||||
|
||||
impl JWTVerifier for MockJWTVerifier {
|
||||
type Key = MockJWK;
|
||||
fn verify<T: serde::de::DeserializeOwned>(
|
||||
&self,
|
||||
token: &str,
|
||||
) -> Result<T, OAuthVerifyError> {
|
||||
if self.id == 0 {
|
||||
Err(OAuthVerifyError::InvalidSignature)
|
||||
} else {
|
||||
Ok(serde_json::from_str(token).unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let jwk_verifiers = vec![MockJWTVerifier { id: 0 }, MockJWTVerifier { id: 1 }];
|
||||
let settings = Settings {
|
||||
fxa_oauth_server_url: mockito::server_url(),
|
||||
..Settings::default()
|
||||
};
|
||||
let verifier: Verifier<MockJWTVerifier> = Verifier::new(&settings, jwk_verifiers).unwrap();
|
||||
|
||||
let token_claims = TokenClaims {
|
||||
user: "fxa_id".to_string(),
|
||||
scope: SYNC_SCOPE.to_string(),
|
||||
generation: Some(124),
|
||||
};
|
||||
|
||||
let res = verifier
|
||||
.verify(
|
||||
serde_json::to_string(&token_claims).unwrap(),
|
||||
&Default::default(),
|
||||
)
|
||||
.await?;
|
||||
assert_eq!(res.fxa_uid, "fxa_id");
|
||||
assert_eq!(res.generation.unwrap(), 124);
|
||||
mock.expect(0); // We shouldn't have hit the server
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_verifier_all_signature_failures_fails() -> Result<(), TokenserverError> {
|
||||
let mock_verify = mockito::mock("POST", "/v1/verify").create();
|
||||
mock_jwk_verifier!(Err(OAuthVerifyError::InvalidSignature));
|
||||
|
||||
let jwk_verifiers = vec![MockJWTVerifier {}, MockJWTVerifier {}];
|
||||
let settings = Settings {
|
||||
fxa_oauth_server_url: mockito::server_url(),
|
||||
..Settings::default()
|
||||
};
|
||||
let verifier: Verifier<MockJWTVerifier> = Verifier::new(&settings, jwk_verifiers).unwrap();
|
||||
let err = verifier
|
||||
.verify(
|
||||
"a token with an invalid signature".to_string(),
|
||||
&Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.status, "invalid-credentials");
|
||||
assert_eq!(err.http_status, 401);
|
||||
assert_eq!(err.description, "Unauthorized");
|
||||
|
||||
mock_verify.expect(0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_verifier_fallsback_if_decode_error() -> Result<(), TokenserverError> {
|
||||
let body = json!({
|
||||
"user": "fxa_id",
|
||||
"scope": [SYNC_SCOPE],
|
||||
"generation": 123
|
||||
});
|
||||
let mock_verify = mockito::mock("POST", "/v1/verify")
|
||||
.with_header("content-type", "application/json")
|
||||
.with_status(200)
|
||||
.with_body(body.to_string())
|
||||
.create();
|
||||
|
||||
mock_jwk_verifier!(Err(OAuthVerifyError::DecodingError));
|
||||
|
||||
let jwk_verifiers = vec![MockJWTVerifier {}];
|
||||
let settings = Settings {
|
||||
fxa_oauth_server_url: mockito::server_url(),
|
||||
..Settings::default()
|
||||
};
|
||||
let verifier: Verifier<MockJWTVerifier> = Verifier::new(&settings, jwk_verifiers).unwrap();
|
||||
|
||||
let res = verifier
|
||||
.verify(
|
||||
"invalid token that can't be decoded".to_string(),
|
||||
&Default::default(),
|
||||
)
|
||||
.await?;
|
||||
assert_eq!(res.fxa_uid, "fxa_id");
|
||||
assert_eq!(res.generation.unwrap(), 123);
|
||||
mock_verify.expect(1); // We would have have hit the server
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_no_sync_scope_fails() -> Result<(), TokenserverError> {
|
||||
let token_claims = TokenClaims {
|
||||
user: "fxa_id".to_string(),
|
||||
scope: "some other scope".to_string(),
|
||||
generation: Some(124),
|
||||
};
|
||||
mock_jwk_verifier!(token, Ok(serde_json::from_str(token).unwrap()));
|
||||
let jwk_verifiers = vec![MockJWTVerifier {}];
|
||||
let settings = Settings {
|
||||
fxa_oauth_server_url: mockito::server_url(),
|
||||
..Settings::default()
|
||||
};
|
||||
let verifier: Verifier<MockJWTVerifier> = Verifier::new(&settings, jwk_verifiers).unwrap();
|
||||
let err = verifier
|
||||
.verify(
|
||||
serde_json::to_string(&token_claims).unwrap(),
|
||||
&Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.status, "invalid-credentials");
|
||||
assert_eq!(err.http_status, 401);
|
||||
assert_eq!(err.description, "Unauthorized");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fxa_rejects_token_no_matter_the_body() -> Result<(), TokenserverError> {
|
||||
let body = json!({
|
||||
"user": "fxa_id",
|
||||
"scope": [SYNC_SCOPE],
|
||||
"generation": 123
|
||||
});
|
||||
let mock_verify = mockito::mock("POST", "/v1/verify")
|
||||
.with_header("content-type", "application/json")
|
||||
.with_status(401)
|
||||
// Even though the body is fine, if FxA returns a none-200, we automatically
|
||||
// return a credential error
|
||||
.with_body(body.to_string())
|
||||
.create();
|
||||
let settings = Settings {
|
||||
fxa_oauth_server_url: mockito::server_url(),
|
||||
..Settings::default()
|
||||
};
|
||||
|
||||
mock_jwk_verifier!(Err(OAuthVerifyError::DecodingError));
|
||||
let jwk_verifiers = vec![];
|
||||
|
||||
let verifier: Verifier<MockJWTVerifier> = Verifier::new(&settings, jwk_verifiers).unwrap();
|
||||
|
||||
let err = verifier
|
||||
.verify(
|
||||
"A token that we will ask FxA about".to_string(),
|
||||
&Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.status, "invalid-credentials");
|
||||
assert_eq!(err.http_status, 401);
|
||||
assert_eq!(err.description, "Unauthorized");
|
||||
mock_verify.expect(1);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fxa_accepts_token_but_bad_body() -> Result<(), TokenserverError> {
|
||||
let body = json!({
|
||||
"bad_key": "foo",
|
||||
"scope": [SYNC_SCOPE],
|
||||
"bad_genreation": 123
|
||||
});
|
||||
let mock_verify = mockito::mock("POST", "/v1/verify")
|
||||
.with_header("content-type", "application/json")
|
||||
.with_status(200)
|
||||
// Even though the body is valid json, it doesn't match our expectation so we'll error
|
||||
// out
|
||||
.with_body(body.to_string())
|
||||
.create();
|
||||
let settings = Settings {
|
||||
fxa_oauth_server_url: mockito::server_url(),
|
||||
..Settings::default()
|
||||
};
|
||||
|
||||
mock_jwk_verifier!(Err(OAuthVerifyError::DecodingError));
|
||||
let jwk_verifiers = vec![];
|
||||
|
||||
let verifier: Verifier<MockJWTVerifier> = Verifier::new(&settings, jwk_verifiers).unwrap();
|
||||
|
||||
let err = verifier
|
||||
.verify(
|
||||
"A token that we will ask FxA about".to_string(),
|
||||
&Default::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.status, "invalid-credentials");
|
||||
assert_eq!(err.http_status, 401);
|
||||
assert_eq!(err.description, "Unauthorized");
|
||||
mock_verify.expect(1);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
195
tokenserver-auth/src/oauth/py.rs
Normal file
195
tokenserver-auth/src/oauth/py.rs
Normal file
@ -0,0 +1,195 @@
|
||||
use async_trait::async_trait;
|
||||
use jsonwebtoken::jwk::{AlgorithmParameters, Jwk, PublicKeyUse, RSAKeyParameters};
|
||||
use pyo3::{
|
||||
prelude::{Py, PyAny, PyErr, PyModule, Python},
|
||||
types::{IntoPyDict, PyString},
|
||||
};
|
||||
use serde_json;
|
||||
use syncserver_common::{BlockingThreadpool, Metrics};
|
||||
use tokenserver_common::TokenserverError;
|
||||
use tokenserver_settings::Settings;
|
||||
use tokio::time;
|
||||
|
||||
use super::VerifyOutput;
|
||||
use crate::VerifyToken;
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
/// The verifier used to verify OAuth tokens.
|
||||
#[derive(Clone)]
|
||||
pub struct Verifier {
|
||||
// Note that we do not need to use an Arc here, since Py is already a reference-counted
|
||||
// pointer
|
||||
inner: Py<PyAny>,
|
||||
timeout: u64,
|
||||
blocking_threadpool: Arc<BlockingThreadpool>,
|
||||
}
|
||||
|
||||
impl Verifier {
|
||||
const FILENAME: &'static str = "verify.py";
|
||||
|
||||
pub fn new(
|
||||
settings: &Settings,
|
||||
blocking_threadpool: Arc<BlockingThreadpool>,
|
||||
) -> Result<Self, TokenserverError> {
|
||||
let inner: Py<PyAny> = Python::with_gil::<_, Result<Py<PyAny>, TokenserverError>>(|py| {
|
||||
let code = include_str!("verify.py");
|
||||
let module = PyModule::from_code(py, code, Self::FILENAME, Self::FILENAME)
|
||||
.map_err(pyerr_to_tokenserver_error)?;
|
||||
let kwargs = {
|
||||
let dict = [("server_url", &settings.fxa_oauth_server_url)].into_py_dict(py);
|
||||
let parse_jwk = |jwk: &Jwk| {
|
||||
let (n, e) = match &jwk.algorithm {
|
||||
AlgorithmParameters::RSA(RSAKeyParameters { key_type: _, n, e }) => (n, e),
|
||||
_ => return Err(TokenserverError::internal_error()),
|
||||
};
|
||||
let alg = jwk
|
||||
.common
|
||||
.key_algorithm
|
||||
.ok_or_else(TokenserverError::internal_error)?
|
||||
.to_string();
|
||||
let kid = jwk
|
||||
.common
|
||||
.key_id
|
||||
.as_ref()
|
||||
.ok_or_else(TokenserverError::internal_error)?;
|
||||
if !matches!(
|
||||
jwk.common
|
||||
.public_key_use
|
||||
.as_ref()
|
||||
.ok_or_else(TokenserverError::internal_error)?,
|
||||
PublicKeyUse::Signature
|
||||
) {
|
||||
return Err(TokenserverError::internal_error());
|
||||
}
|
||||
|
||||
let dict = [
|
||||
("kty", "RSA"),
|
||||
("alg", &alg),
|
||||
("kid", kid),
|
||||
("use", "sig"),
|
||||
("n", &n),
|
||||
("e", e),
|
||||
]
|
||||
.into_py_dict(py);
|
||||
Ok(dict)
|
||||
};
|
||||
|
||||
let jwks = match (
|
||||
&settings.fxa_oauth_primary_jwk,
|
||||
&settings.fxa_oauth_secondary_jwk,
|
||||
) {
|
||||
(Some(primary_jwk), Some(secondary_jwk)) => {
|
||||
Some(vec![parse_jwk(primary_jwk)?, parse_jwk(secondary_jwk)?])
|
||||
}
|
||||
(Some(jwk), None) | (None, Some(jwk)) => Some(vec![parse_jwk(jwk)?]),
|
||||
(None, None) => None,
|
||||
};
|
||||
dict.set_item("jwks", jwks).unwrap();
|
||||
dict
|
||||
};
|
||||
let object: Py<PyAny> = module
|
||||
.getattr("FxaOAuthClient")
|
||||
.map_err(pyerr_to_tokenserver_error)?
|
||||
.call((), Some(kwargs))
|
||||
.map_err(|e| {
|
||||
e.print_and_set_sys_last_vars(py);
|
||||
pyerr_to_tokenserver_error(e)
|
||||
})?
|
||||
.into();
|
||||
|
||||
Ok(object)
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
timeout: settings.fxa_oauth_request_timeout,
|
||||
blocking_threadpool,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl VerifyToken for Verifier {
|
||||
type Output = VerifyOutput;
|
||||
|
||||
/// Verifies an OAuth token. Returns `VerifyOutput` for valid tokens and a `TokenserverError`
|
||||
/// for invalid tokens.
|
||||
async fn verify(
|
||||
&self,
|
||||
token: String,
|
||||
_metrics: &Metrics,
|
||||
) -> Result<VerifyOutput, TokenserverError> {
|
||||
// We don't want to move `self` into the body of the closure here because we'd need to
|
||||
// clone it. Cloning it is only necessary if we need to verify the token remotely via FxA,
|
||||
// since that would require passing `self` to a separate thread. Passing &Self to a closure
|
||||
// gives us the flexibility to clone only when necessary.
|
||||
let verify_inner = |verifier: &Self| {
|
||||
let maybe_verify_output_string = Python::with_gil(|py| {
|
||||
let client = verifier.inner.as_ref(py);
|
||||
// `client.verify_token(token)`
|
||||
let result: &PyAny = client
|
||||
.getattr("verify_token")?
|
||||
.call((token,), None)
|
||||
.map_err(|e| {
|
||||
e.print_and_set_sys_last_vars(py);
|
||||
e
|
||||
})?;
|
||||
|
||||
if result.is_none() {
|
||||
Ok(None)
|
||||
} else {
|
||||
let verify_output_python_string = result.downcast::<PyString>()?;
|
||||
verify_output_python_string.extract::<String>().map(Some)
|
||||
}
|
||||
})
|
||||
.map_err(|e| TokenserverError {
|
||||
context: format!("pyo3 error in OAuth verifier: {}", e),
|
||||
..TokenserverError::invalid_credentials("Unauthorized".to_owned())
|
||||
})?;
|
||||
|
||||
match maybe_verify_output_string {
|
||||
Some(verify_output_string) => {
|
||||
serde_json::from_str::<VerifyOutput>(&verify_output_string).map_err(|e| {
|
||||
TokenserverError {
|
||||
context: format!("Invalid OAuth verify output: {}", e),
|
||||
..TokenserverError::invalid_credentials("Unauthorized".to_owned())
|
||||
}
|
||||
})
|
||||
}
|
||||
None => Err(TokenserverError {
|
||||
context: "Invalid OAuth token".to_owned(),
|
||||
..TokenserverError::invalid_credentials("Unauthorized".to_owned())
|
||||
}),
|
||||
}
|
||||
};
|
||||
|
||||
let verifier = self.clone();
|
||||
|
||||
// If the JWK is not cached or if the token is not a JWT/wasn't signed by a known key
|
||||
// type, PyFxA will make a request to FxA to retrieve it, blocking this thread. To
|
||||
// improve performance, we make the request on a thread in a threadpool specifically
|
||||
// used for blocking operations. The JWK should _always_ be cached in production to
|
||||
// maximize performance.
|
||||
let fut = self
|
||||
.blocking_threadpool
|
||||
.spawn(move || verify_inner(&verifier));
|
||||
|
||||
// The PyFxA OAuth client does not offer a way to set a request timeout, so we set one here
|
||||
// by timing out the future if the verification process blocks this thread for longer
|
||||
// than the specified number of seconds.
|
||||
time::timeout(Duration::from_secs(self.timeout), fut)
|
||||
.await
|
||||
.map_err(|_| TokenserverError {
|
||||
context: "OAuth verification timeout".to_owned(),
|
||||
..TokenserverError::resource_unavailable()
|
||||
})?
|
||||
}
|
||||
}
|
||||
|
||||
fn pyerr_to_tokenserver_error(e: PyErr) -> TokenserverError {
|
||||
TokenserverError {
|
||||
context: e.to_string(),
|
||||
..TokenserverError::internal_error()
|
||||
}
|
||||
}
|
||||
11
tokenserver-auth/src/token.rs
Normal file
11
tokenserver-auth/src/token.rs
Normal file
@ -0,0 +1,11 @@
|
||||
#[cfg(not(feature = "py"))]
|
||||
mod native;
|
||||
|
||||
#[cfg(feature = "py")]
|
||||
mod py;
|
||||
|
||||
#[cfg(feature = "py")]
|
||||
pub type Tokenlib = py::PyTokenlib;
|
||||
|
||||
#[cfg(not(feature = "py"))]
|
||||
pub type Tokenlib = native::Tokenlib;
|
||||
113
tokenserver-auth/src/token/native.rs
Normal file
113
tokenserver-auth/src/token/native.rs
Normal file
@ -0,0 +1,113 @@
|
||||
use crate::{
|
||||
crypto::{Crypto, CryptoImpl},
|
||||
MakeTokenPlaintext,
|
||||
};
|
||||
use base64::Engine;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokenserver_common::TokenserverError;
|
||||
// Those two constants were pulled directly from
|
||||
// https://github.com/mozilla-services/tokenlib/blob/91ec9e2c922e55306eddba1394590a88f3b10602/tokenlib/__init__.py#L43-L45
|
||||
// We could change them, but we'd want to make sure that we also change them syncstorage, however
|
||||
// that would cause temporary auth issues for anyone with an old pre-new-value token
|
||||
const HKDF_SIGNING_INFO: &[u8] = b"services.mozilla.com/tokenlib/v1/signing";
|
||||
const HKDF_INFO_DERIVE: &[u8] = b"services.mozilla.com/tokenlib/v1/derive/";
|
||||
|
||||
pub struct Tokenlib {}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Token<'a> {
|
||||
#[serde(flatten)]
|
||||
plaintext: MakeTokenPlaintext,
|
||||
salt: &'a str,
|
||||
}
|
||||
|
||||
impl Tokenlib {
|
||||
pub fn get_token_and_derived_secret(
|
||||
plaintext: MakeTokenPlaintext,
|
||||
shared_secret: &str,
|
||||
) -> Result<(String, String), TokenserverError> {
|
||||
// First we make the token itself, the code blow was ported from:
|
||||
// https://github.com/mozilla-services/tokenlib/blob/91ec9e2c922e55306eddba1394590a88f3b10602/tokenlib/__init__.py#L96-L97
|
||||
let crypto_lib = CryptoImpl {};
|
||||
let mut salt_bytes = [0u8; 3];
|
||||
crypto_lib.rand_bytes(&mut salt_bytes)?;
|
||||
let salt = hex::encode(salt_bytes);
|
||||
let token_str = serde_json::to_string(&Token {
|
||||
plaintext,
|
||||
salt: &salt,
|
||||
})
|
||||
.map_err(|_| TokenserverError::internal_error())?;
|
||||
let hmac_key = crypto_lib.hkdf(shared_secret, None, HKDF_SIGNING_INFO)?;
|
||||
let signature = crypto_lib.hmac_sign(&hmac_key, token_str.as_bytes())?;
|
||||
let mut token_bytes = Vec::with_capacity(token_str.len() + signature.len());
|
||||
token_bytes.extend_from_slice(token_str.as_bytes());
|
||||
token_bytes.extend_from_slice(&signature);
|
||||
let token = base64::engine::general_purpose::URL_SAFE.encode(token_bytes);
|
||||
// Now that we finialized the token, lets generate our per token secret
|
||||
// The code below was ported from:
|
||||
// https://github.com/mozilla-services/tokenlib/blob/91ec9e2c922e55306eddba1394590a88f3b10602/tokenlib/__init__.py#L158-L159
|
||||
let mut info = Vec::with_capacity(HKDF_INFO_DERIVE.len() + token.as_bytes().len());
|
||||
info.extend_from_slice(HKDF_INFO_DERIVE);
|
||||
info.extend_from_slice(token.as_bytes());
|
||||
|
||||
let per_token_secret = crypto_lib.hkdf(shared_secret, Some(salt.as_bytes()), &info)?;
|
||||
let per_token_secret = base64::engine::general_purpose::URL_SAFE.encode(per_token_secret);
|
||||
Ok((token, per_token_secret))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{crypto::SHA256_OUTPUT_LEN, TokenserverOrigin};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_generate_valid_token_and_per_token_secret() -> Result<(), TokenserverError> {
|
||||
// First we verify that the token we generated has a valid
|
||||
// and correct HMAC signature if signed using the same key
|
||||
let plaintext = MakeTokenPlaintext {
|
||||
node: "https://www.example.com".to_string(),
|
||||
fxa_kid: "kid".to_string(),
|
||||
fxa_uid: "user uid".to_string(),
|
||||
hashed_fxa_uid: "hased uid".to_string(),
|
||||
hashed_device_id: "hashed device id".to_string(),
|
||||
expires: 1031,
|
||||
uid: 13,
|
||||
tokenserver_origin: TokenserverOrigin::Rust,
|
||||
};
|
||||
let secret = "foobar";
|
||||
let crypto_impl = CryptoImpl {};
|
||||
let hmac_key = crypto_impl.hkdf(secret, None, HKDF_SIGNING_INFO).unwrap();
|
||||
let (b64_token, per_token_secret) =
|
||||
Tokenlib::get_token_and_derived_secret(plaintext.clone(), secret).unwrap();
|
||||
let token = base64::engine::general_purpose::URL_SAFE
|
||||
.decode(&b64_token)
|
||||
.unwrap();
|
||||
let token_size = token.len();
|
||||
let signature = &token[token_size - SHA256_OUTPUT_LEN..];
|
||||
let payload = &token[..token_size - SHA256_OUTPUT_LEN];
|
||||
crypto_impl
|
||||
.hmac_verify(&hmac_key, payload, signature)
|
||||
.unwrap();
|
||||
// Then we verify that the payload value we signed, is a valid
|
||||
// Token represented by our Token struct, and has exactly the same
|
||||
// plain_text values
|
||||
let token_data = serde_json::from_slice::<Token<'_>>(payload).unwrap();
|
||||
assert_eq!(token_data.plaintext, plaintext);
|
||||
// Finally, we verify that the same per_token_secret can be derived given the payload
|
||||
// and the shared secret
|
||||
let mut info = Vec::with_capacity(HKDF_INFO_DERIVE.len() + b64_token.as_bytes().len());
|
||||
info.extend_from_slice(HKDF_INFO_DERIVE);
|
||||
info.extend_from_slice(b64_token.as_bytes());
|
||||
|
||||
let expected_per_token_secret =
|
||||
crypto_impl.hkdf(secret, Some(token_data.salt.as_bytes()), &info)?;
|
||||
let expected_per_token_secret =
|
||||
base64::engine::general_purpose::URL_SAFE.encode(expected_per_token_secret);
|
||||
|
||||
assert_eq!(expected_per_token_secret, per_token_secret);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
71
tokenserver-auth/src/token/py.rs
Normal file
71
tokenserver-auth/src/token/py.rs
Normal file
@ -0,0 +1,71 @@
|
||||
use crate::{MakeTokenPlaintext, TokenserverError};
|
||||
use pyo3::{
|
||||
prelude::{IntoPy, PyErr, PyModule, PyObject, Python},
|
||||
types::IntoPyDict,
|
||||
};
|
||||
|
||||
pub struct PyTokenlib {}
|
||||
impl IntoPy<PyObject> for MakeTokenPlaintext {
|
||||
fn into_py(self, py: Python<'_>) -> PyObject {
|
||||
let dict = [
|
||||
("node", self.node),
|
||||
("fxa_kid", self.fxa_kid),
|
||||
("fxa_uid", self.fxa_uid),
|
||||
("hashed_device_id", self.hashed_device_id),
|
||||
("hashed_fxa_uid", self.hashed_fxa_uid),
|
||||
("tokenserver_origin", self.tokenserver_origin.to_string()),
|
||||
]
|
||||
.into_py_dict(py);
|
||||
|
||||
// These need to be set separately since they aren't strings, and
|
||||
// Rust doesn't support heterogeneous arrays
|
||||
dict.set_item("expires", self.expires).unwrap();
|
||||
dict.set_item("uid", self.uid).unwrap();
|
||||
|
||||
dict.into()
|
||||
}
|
||||
}
|
||||
impl PyTokenlib {
|
||||
pub fn get_token_and_derived_secret(
|
||||
plaintext: MakeTokenPlaintext,
|
||||
shared_secret: &str,
|
||||
) -> Result<(String, String), TokenserverError> {
|
||||
Python::with_gil(|py| {
|
||||
// `import tokenlib`
|
||||
let module = PyModule::import(py, "tokenlib").map_err(|e| {
|
||||
e.print_and_set_sys_last_vars(py);
|
||||
e
|
||||
})?;
|
||||
// `kwargs = { 'secret': shared_secret }`
|
||||
let kwargs = [("secret", shared_secret)].into_py_dict(py);
|
||||
// `token = tokenlib.make_token(plaintext, **kwargs)`
|
||||
let token = module
|
||||
.getattr("make_token")?
|
||||
.call((plaintext,), Some(kwargs))
|
||||
.map_err(|e| {
|
||||
e.print_and_set_sys_last_vars(py);
|
||||
e
|
||||
})
|
||||
.and_then(|x| x.extract())?;
|
||||
// `derived_secret = tokenlib.get_derived_secret(token, **kwargs)`
|
||||
let derived_secret = module
|
||||
.getattr("get_derived_secret")?
|
||||
.call((&token,), Some(kwargs))
|
||||
.map_err(|e| {
|
||||
e.print_and_set_sys_last_vars(py);
|
||||
e
|
||||
})
|
||||
.and_then(|x| x.extract())?;
|
||||
// `return (token, derived_secret)`
|
||||
Ok((token, derived_secret))
|
||||
})
|
||||
.map_err(pyerr_to_tokenserver_error)
|
||||
}
|
||||
}
|
||||
|
||||
fn pyerr_to_tokenserver_error(e: PyErr) -> TokenserverError {
|
||||
TokenserverError {
|
||||
context: e.to_string(),
|
||||
..TokenserverError::internal_error()
|
||||
}
|
||||
}
|
||||
@ -10,6 +10,7 @@ actix-web.workspace = true
|
||||
backtrace.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
thiserror.workspace = true
|
||||
|
||||
syncserver-common = { path = "../syncserver-common" }
|
||||
thiserror = "1.0.26"
|
||||
|
||||
@ -13,6 +13,7 @@ serde.workspace = true
|
||||
serde_derive.workspace = true
|
||||
serde_json.workspace = true
|
||||
slog-scope.workspace = true
|
||||
thiserror.workspace = true
|
||||
|
||||
async-trait = "0.1.40"
|
||||
diesel = { version = "1.4", features = ["mysql", "r2d2"] }
|
||||
@ -20,7 +21,6 @@ diesel_logger = "0.1.1"
|
||||
diesel_migrations = { version = "1.4.0", features = ["mysql"] }
|
||||
syncserver-common = { path = "../syncserver-common" }
|
||||
syncserver-db-common = { path = "../syncserver-db-common" }
|
||||
thiserror = "1.0.26"
|
||||
tokenserver-common = { path = "../tokenserver-common" }
|
||||
tokenserver-settings = { path = "../tokenserver-settings" }
|
||||
tokio = { workspace = true, features = ["macros", "sync"] }
|
||||
|
||||
@ -7,5 +7,6 @@ edition.workspace=true
|
||||
|
||||
[dependencies]
|
||||
serde.workspace=true
|
||||
jsonwebtoken.workspace=true
|
||||
|
||||
tokenserver-common = { path = "../tokenserver-common" }
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
use jsonwebtoken::jwk::Jwk;
|
||||
use serde::Deserialize;
|
||||
use tokenserver_common::NodeType;
|
||||
|
||||
@ -69,18 +70,6 @@ pub struct Settings {
|
||||
pub token_duration: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct Jwk {
|
||||
pub kty: String,
|
||||
pub alg: String,
|
||||
pub kid: String,
|
||||
pub fxa_created_at: u64,
|
||||
#[serde(rename = "use")]
|
||||
pub use_of_key: String,
|
||||
pub n: String,
|
||||
pub e: String,
|
||||
}
|
||||
|
||||
impl Default for Settings {
|
||||
fn default() -> Settings {
|
||||
Settings {
|
||||
|
||||
@ -215,14 +215,14 @@ class TestE2e(TestCase, unittest.TestCase):
|
||||
raw = urlsafe_b64decode(res.json['id'])
|
||||
payload = raw[:-32]
|
||||
signature = raw[-32:]
|
||||
payload_dict = json.loads(payload.decode('utf-8'))
|
||||
payload_str = payload.decode('utf-8')
|
||||
payload_dict = json.loads(payload_str)
|
||||
# The `id` payload should include a field indicating the origin of the
|
||||
# token
|
||||
self.assertEqual(payload_dict['tokenserver_origin'], 'rust')
|
||||
signing_secret = self.TOKEN_SIGNING_SECRET
|
||||
expected_token = tokenlib.make_token(payload_dict,
|
||||
secret=signing_secret)
|
||||
expected_signature = urlsafe_b64decode(expected_token)[-32:]
|
||||
tm = tokenlib.TokenManager(secret=signing_secret)
|
||||
expected_signature = tm._get_signature(payload_str.encode('utf8'))
|
||||
# Using the #compare_digest method here is not strictly necessary, as
|
||||
# this is not a security-sensitive situation, but it's good practice
|
||||
self.assertTrue(hmac.compare_digest(expected_signature, signature))
|
||||
@ -271,12 +271,11 @@ class TestE2e(TestCase, unittest.TestCase):
|
||||
raw = urlsafe_b64decode(res.json['id'])
|
||||
payload = raw[:-32]
|
||||
signature = raw[-32:]
|
||||
payload_dict = json.loads(payload.decode('utf-8'))
|
||||
payload_str = payload.decode('utf-8')
|
||||
|
||||
signing_secret = self.TOKEN_SIGNING_SECRET
|
||||
expected_token = tokenlib.make_token(payload_dict,
|
||||
secret=signing_secret)
|
||||
expected_signature = urlsafe_b64decode(expected_token)[-32:]
|
||||
tm = tokenlib.TokenManager(secret=signing_secret)
|
||||
expected_signature = tm._get_signature(payload_str.encode('utf8'))
|
||||
# Using the #compare_digest method here is not strictly necessary, as
|
||||
# this is not a security-sensitive situation, but it's good practice
|
||||
self.assertTrue(hmac.compare_digest(expected_signature, signature))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user