diff --git a/src/api/common/cors.rs b/src/api/common/cors.rs index 6f524bf4..94ff8aa9 100644 --- a/src/api/common/cors.rs +++ b/src/api/common/cors.rs @@ -1,8 +1,9 @@ use std::sync::Arc; use http::header::{ - ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, - ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, + HeaderValue, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, + ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_REQUEST_HEADERS, + ACCESS_CONTROL_REQUEST_METHOD, VARY, }; use hyper::{body::Body, body::Incoming as IncomingBody, Request, Response, StatusCode}; @@ -12,10 +13,12 @@ use garage_model::garage::Garage; use crate::common_error::{CommonError, OkOrBadRequest, OkOrInternalError}; use crate::helpers::*; +// Return both the matching rule and the parsed Origin header so callers that +// apply CORS headers don't have to repeat Origin lookup and validation. pub fn find_matching_cors_rule<'a, B>( bucket_params: &'a BucketParams, - req: &Request, -) -> Result, CommonError> { + req: &'a Request, +) -> Result, CommonError> { if let Some(cors_config) = bucket_params.cors_config.get() { if let Some(origin) = req.headers().get("Origin") { let origin = origin.to_str()?; @@ -23,9 +26,12 @@ pub fn find_matching_cors_rule<'a, B>( Some(h) => h.to_str()?.split(',').map(|h| h.trim()).collect::>(), None => vec![], }; - return Ok(cors_config.iter().find(|rule| { - cors_rule_matches(rule, origin, req.method().as_ref(), request_headers.iter()) - })); + return Ok(cors_config + .iter() + .find(|rule| { + cors_rule_matches(rule, origin, req.method().as_ref(), request_headers.iter()) + }) + .map(|rule| (rule, origin))); } } Ok(None) @@ -53,12 +59,16 @@ where pub fn add_cors_headers( resp: &mut Response, rule: &GarageCorsRule, + request_origin: &str, ) -> Result<(), http::header::InvalidHeaderValue> { let h = resp.headers_mut(); - h.insert( - ACCESS_CONTROL_ALLOW_ORIGIN, - rule.allow_origins.join(", ").parse()?, - ); + let is_wildcard_origin = rule.allow_origins.iter().any(|origin| origin == "*"); + let allow_origin = if is_wildcard_origin { + "*" + } else { + request_origin + }; + h.insert(ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin.parse()?); h.insert( ACCESS_CONTROL_ALLOW_METHODS, rule.allow_methods.join(", ").parse()?, @@ -71,6 +81,12 @@ pub fn add_cors_headers( ACCESS_CONTROL_EXPOSE_HEADERS, rule.expose_headers.join(", ").parse()?, ); + // When ACAO reflects the request origin instead of returning "*", + // caches must vary on the Origin request header to avoid reusing + // a response generated for one origin when serving another origin. + if !is_wildcard_origin { + h.insert(VARY, HeaderValue::from_static("Origin")); + } Ok(()) } @@ -149,7 +165,17 @@ pub fn handle_options_for_bucket( let mut resp = Response::builder() .status(StatusCode::OK) .body(EmptyBody::new())?; - add_cors_headers(&mut resp, rule).ok_or_internal_error("Invalid CORS configuration")?; + add_cors_headers(&mut resp, rule, origin) + .ok_or_internal_error("Invalid CORS configuration")?; + // Preflight responses vary not only on Origin but also on the + // requested method and requested headers, so caches must not + // reuse one preflight decision for a different preflight input. + resp.headers_mut().insert( + VARY, + "Origin, Access-Control-Request-Method, Access-Control-Request-Headers" + .parse() + .expect("static vary header"), + ); return Ok(resp); } } @@ -158,3 +184,98 @@ pub fn handle_options_for_bucket( "This CORS request is not allowed.".into(), )) } + +#[cfg(test)] +mod tests { + use super::*; + + fn bucket_params_with_rule(allow_origins: Vec<&str>) -> BucketParams { + let mut bucket_params = BucketParams::default(); + bucket_params.cors_config.update(Some(vec![GarageCorsRule { + id: Some("cors-test".into()), + max_age_seconds: None, + allow_origins: allow_origins.into_iter().map(str::to_string).collect(), + allow_methods: vec!["GET".into(), "PUT".into()], + allow_headers: vec!["*".into()], + expose_headers: vec![], + }])); + bucket_params + } + + fn preflight_request(origin: &str) -> Request<()> { + Request::builder() + .method("OPTIONS") + .uri("http://example.test/bucket") + .header("Origin", origin) + .header(ACCESS_CONTROL_REQUEST_METHOD, "PUT") + .body(()) + .unwrap() + } + + #[test] + fn preflight_with_single_allowed_origin_returns_request_origin() { + let bucket_params = bucket_params_with_rule(vec!["https://app.example.test"]); + let req = preflight_request("https://app.example.test"); + + let resp = handle_options_for_bucket(&req, &bucket_params).unwrap(); + + assert_eq!( + resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(), + "https://app.example.test" + ); + let vary_values: Vec<_> = resp + .headers() + .get_all(VARY) + .iter() + .map(|value| value.to_str().unwrap()) + .collect(); + assert_eq!( + vary_values, + vec!["Origin, Access-Control-Request-Method, Access-Control-Request-Headers",] + ); + } + + #[test] + fn preflight_with_multiple_allowed_origins_reflects_request_origin() { + let bucket_params = bucket_params_with_rule(vec![ + "https://app.example.test", + "https://admin.example.test", + ]); + let req = preflight_request("https://app.example.test"); + + let resp = handle_options_for_bucket(&req, &bucket_params).unwrap(); + + // This assertion documents the behavior browsers expect: + // even if multiple origins are allowed by configuration, the + // response should reflect the request origin rather than emit + // a comma-separated list. It currently fails and is meant to + // turn green once header generation is corrected. + assert_eq!( + resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(), + "https://app.example.test" + ); + } + + #[test] + fn preflight_with_wildcard_allowed_origin_returns_wildcard() { + let bucket_params = bucket_params_with_rule(vec!["*"]); + let req = preflight_request("https://app.example.test"); + + let resp = handle_options_for_bucket(&req, &bucket_params).unwrap(); + + assert_eq!( + resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(), + "*" + ); + let vary_values: Vec<_> = resp + .headers() + .get_all(VARY) + .iter() + .map(|value| value.to_str().unwrap()) + .collect(); + assert_eq!( + vary_values, + vec!["Origin, Access-Control-Request-Method, Access-Control-Request-Headers",] + ); + } +} diff --git a/src/api/k2v/api_server.rs b/src/api/k2v/api_server.rs index 8c89c35d..6bc6eaaa 100644 --- a/src/api/k2v/api_server.rs +++ b/src/api/k2v/api_server.rs @@ -111,7 +111,7 @@ impl ApiHandler for K2VApiServer { Method::GET | Method::HEAD | Method::POST => { find_matching_cors_rule(&bucket_params, &req) .ok_or_internal_error("Error looking up CORS rule")? - .cloned() + .map(|(rule, origin)| (rule.clone(), origin.to_string())) } _ => None, }; @@ -164,8 +164,8 @@ impl ApiHandler for K2VApiServer { // If request was a success and we have a CORS rule that applies to it, // add the corresponding CORS headers to the response let mut resp_ok = resp?; - if let Some(rule) = matching_cors_rule { - add_cors_headers(&mut resp_ok, &rule) + if let Some((rule, origin)) = matching_cors_rule { + add_cors_headers(&mut resp_ok, &rule, &origin) .ok_or_internal_error("Invalid bucket CORS configuration")?; } diff --git a/src/api/s3/api_server.rs b/src/api/s3/api_server.rs index 689d174d..ce1c2813 100644 --- a/src/api/s3/api_server.rs +++ b/src/api/s3/api_server.rs @@ -159,7 +159,8 @@ impl ApiHandler for S3ApiServer { return Err(Error::forbidden("Operation is not allowed for this key.")); } - let matching_cors_rule = find_matching_cors_rule(&bucket_params, &req)?.cloned(); + let matching_cors = find_matching_cors_rule(&bucket_params, &req)? + .map(|(rule, origin)| (rule.clone(), origin.to_string())); let ctx = ReqCtx { garage, @@ -334,8 +335,8 @@ impl ApiHandler for S3ApiServer { // If request was a success and we have a CORS rule that applies to it, // add the corresponding CORS headers to the response let mut resp_ok = resp?; - if let Some(rule) = matching_cors_rule { - add_cors_headers(&mut resp_ok, &rule) + if let Some((rule, origin)) = matching_cors { + add_cors_headers(&mut resp_ok, &rule, &origin) .ok_or_internal_error("Invalid bucket CORS configuration")?; } diff --git a/src/api/s3/post_object.rs b/src/api/s3/post_object.rs index 26fd454c..711e244b 100644 --- a/src/api/s3/post_object.rs +++ b/src/api/s3/post_object.rs @@ -121,7 +121,7 @@ pub async fn handle_post_object( &bucket_params, &Request::from_parts(head.clone(), empty_body::()), )? - .cloned(); + .map(|(rule, origin)| (rule.clone(), origin.to_string())); let decoded_policy = BASE64_STANDARD .decode(policy) @@ -351,8 +351,8 @@ pub async fn handle_post_object( } }; - if let Some(rule) = matching_cors_rule { - add_cors_headers(&mut resp, &rule) + if let Some((rule, origin)) = matching_cors_rule { + add_cors_headers(&mut resp, &rule, &origin) .ok_or_internal_error("Invalid bucket CORS configuration")?; } diff --git a/src/garage/tests/s3/cors.rs b/src/garage/tests/s3/cors.rs new file mode 100644 index 00000000..437e3e2d --- /dev/null +++ b/src/garage/tests/s3/cors.rs @@ -0,0 +1,121 @@ +use aws_sdk_s3::types::{CorsConfiguration, CorsRule}; +use hyper::{Method, StatusCode}; + +use crate::common; + +const REQUEST_ORIGIN: &str = "https://app.example.test"; +const SECOND_ALLOWED_ORIGIN: &str = "https://admin.example.test"; +const OBJECT_KEY: &str = "probe.txt"; +const BODY: &[u8] = b"hello from integration repro\n"; + +async fn send_preflight( + ctx: &common::Context, + bucket: &str, + origin: &str, +) -> hyper::Response { + ctx.custom_request + .builder(bucket.to_string()) + .method(Method::OPTIONS) + .path(OBJECT_KEY) + .unsigned_header("origin", origin) + .unsigned_header("access-control-request-method", "PUT") + .unsigned_header( + "access-control-request-headers", + "content-type,x-amz-meta-demo", + ) + .body(vec![]) + .send() + .await + .unwrap() +} + +async fn send_put( + ctx: &common::Context, + bucket: &str, + origin: &str, +) -> hyper::Response { + ctx.custom_request + .builder(bucket.to_string()) + .method(Method::PUT) + .path(OBJECT_KEY) + .signed_header("content-type", "text/plain") + .signed_header("x-amz-meta-demo", "1") + .unsigned_header("origin", origin) + .body(BODY.to_vec()) + .send() + .await + .unwrap() +} + +async fn apply_bucket_cors(ctx: &common::Context, bucket: &str, allowed_origins: &[&str]) { + let rule = allowed_origins.iter().fold( + CorsRule::builder() + .allowed_headers("*") + .allowed_methods("PUT") + .expose_headers("ETag"), + |rule, origin| rule.allowed_origins(*origin), + ); + + let cors = CorsConfiguration::builder() + .cors_rules(rule.build().unwrap()) + .build() + .unwrap(); + + ctx.client + .put_bucket_cors() + .bucket(bucket) + .cors_configuration(cors) + .send() + .await + .unwrap(); +} + +#[tokio::test] +async fn test_s3_api_cors_reflects_request_origin() { + let ctx = common::context(); + let bucket = ctx.create_bucket("s3-cors-direct"); + + apply_bucket_cors(&ctx, &bucket, &[REQUEST_ORIGIN]).await; + + let control_preflight = send_preflight(&ctx, &bucket, REQUEST_ORIGIN).await; + assert_eq!(control_preflight.status(), StatusCode::OK); + assert_eq!( + control_preflight + .headers() + .get("access-control-allow-origin") + .unwrap(), + REQUEST_ORIGIN + ); + + let control_put = send_put(&ctx, &bucket, REQUEST_ORIGIN).await; + assert_eq!(control_put.status(), StatusCode::OK); + assert_eq!( + control_put + .headers() + .get("access-control-allow-origin") + .unwrap(), + REQUEST_ORIGIN + ); + + apply_bucket_cors(&ctx, &bucket, &[REQUEST_ORIGIN, SECOND_ALLOWED_ORIGIN]).await; + + let repro_preflight = send_preflight(&ctx, &bucket, REQUEST_ORIGIN).await; + assert_eq!(repro_preflight.status(), StatusCode::OK); + assert_eq!( + repro_preflight + .headers() + .get("access-control-allow-origin") + .unwrap(), + REQUEST_ORIGIN + ); + + let repro_put = send_put(&ctx, &bucket, REQUEST_ORIGIN).await; + assert_eq!(repro_put.status(), StatusCode::OK); + assert_eq!( + repro_put + .headers() + .get("access-control-allow-origin") + .unwrap(), + REQUEST_ORIGIN + ); +} diff --git a/src/garage/tests/s3/mod.rs b/src/garage/tests/s3/mod.rs index fa081389..bf217513 100644 --- a/src/garage/tests/s3/mod.rs +++ b/src/garage/tests/s3/mod.rs @@ -1,3 +1,4 @@ +mod cors; mod list; mod multipart; mod objects; diff --git a/src/web/web_server.rs b/src/web/web_server.rs index 2d0cac2d..6aac096e 100644 --- a/src/web/web_server.rs +++ b/src/web/web_server.rs @@ -405,8 +405,8 @@ impl WebServer { } Ok(mut resp) => { // Maybe add CORS headers - if let Some(rule) = find_matching_cors_rule(&bucket_params, req)? { - add_cors_headers(&mut resp, rule) + if let Some((rule, origin)) = find_matching_cors_rule(&bucket_params, req)? { + add_cors_headers(&mut resp, rule, origin) .ok_or_internal_error("Invalid bucket CORS configuration")?; } Ok(resp)