From 77b34f88338bad79499f54255b5480bdf73b0e58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Le=C3=A9h?= <52788117+Ptrskay3@users.noreply.github.com> Date: Wed, 19 Jul 2023 10:31:36 +0200 Subject: [PATCH] cors: Add support for private network preflights Co-authored-by: Jonas Platte --- tower-http/Cargo.toml | 2 +- tower-http/src/cors/allow_credentials.rs | 2 + tower-http/src/cors/allow_private_network.rs | 196 +++++++++++++++++++ tower-http/src/cors/mod.rs | 36 +++- 4 files changed, 234 insertions(+), 2 deletions(-) create mode 100644 tower-http/src/cors/allow_private_network.rs diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index 6e145fb2..f7408d87 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -17,7 +17,7 @@ bitflags = "2.0.2" bytes = "1" futures-core = "0.3" futures-util = { version = "0.3.14", default_features = false, features = [] } -http = "0.2.2" +http = "0.2.7" http-body = "0.4.5" pin-project-lite = "0.2.7" tower-layer = "0.3" diff --git a/tower-http/src/cors/allow_credentials.rs b/tower-http/src/cors/allow_credentials.rs index 3843def8..e489c570 100644 --- a/tower-http/src/cors/allow_credentials.rs +++ b/tower-http/src/cors/allow_credentials.rs @@ -27,6 +27,8 @@ impl AllowCredentials { /// Allow credentials for some requests, based on a given predicate /// + /// The first argument to the predicate is the request origin. + /// /// See [`CorsLayer::allow_credentials`] for more details. /// /// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials diff --git a/tower-http/src/cors/allow_private_network.rs b/tower-http/src/cors/allow_private_network.rs new file mode 100644 index 00000000..4163014d --- /dev/null +++ b/tower-http/src/cors/allow_private_network.rs @@ -0,0 +1,196 @@ +use std::{fmt, sync::Arc}; + +use http::{ + header::{HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Private-Network`][wicg] header. +/// +/// See [`CorsLayer::allow_private_network`] for more details. +/// +/// [wicg]: https://wicg.github.io/private-network-access/ +/// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network +#[derive(Clone, Default)] +#[must_use] +pub struct AllowPrivateNetwork(AllowPrivateNetworkInner); + +impl AllowPrivateNetwork { + /// Allow requests via a more private network than the one used to access the origin + /// + /// See [`CorsLayer::allow_private_network`] for more details. + /// + /// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network + pub fn yes() -> Self { + Self(AllowPrivateNetworkInner::Yes) + } + + /// Allow requests via private network for some requests, based on a given predicate + /// + /// The first argument to the predicate is the request origin. + /// + /// See [`CorsLayer::allow_private_network`] for more details. + /// + /// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network + pub fn predicate(f: F) -> Self + where + F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static, + { + Self(AllowPrivateNetworkInner::Predicate(Arc::new(f))) + } + + pub(super) fn to_header( + &self, + origin: Option<&HeaderValue>, + parts: &RequestParts, + ) -> Option<(HeaderName, HeaderValue)> { + #[allow(clippy::declare_interior_mutable_const)] + const REQUEST_PRIVATE_NETWORK: HeaderName = + HeaderName::from_static("access-control-request-private-network"); + + #[allow(clippy::declare_interior_mutable_const)] + const ALLOW_PRIVATE_NETWORK: HeaderName = + HeaderName::from_static("access-control-allow-private-network"); + + const TRUE: HeaderValue = HeaderValue::from_static("true"); + + // Cheapest fallback: allow_private_network hasn't been set + if let AllowPrivateNetworkInner::No = &self.0 { + return None; + } + + // Access-Control-Allow-Private-Network is only relevant if the request + // has the Access-Control-Request-Private-Network header set, else skip + if parts.headers.get(REQUEST_PRIVATE_NETWORK) != Some(&TRUE) { + return None; + } + + let allow_private_network = match &self.0 { + AllowPrivateNetworkInner::Yes => true, + AllowPrivateNetworkInner::No => false, // unreachable, but not harmful + AllowPrivateNetworkInner::Predicate(c) => c(origin?, parts), + }; + + allow_private_network.then(|| (ALLOW_PRIVATE_NETWORK, TRUE)) + } +} + +impl From for AllowPrivateNetwork { + fn from(v: bool) -> Self { + match v { + true => Self(AllowPrivateNetworkInner::Yes), + false => Self(AllowPrivateNetworkInner::No), + } + } +} + +impl fmt::Debug for AllowPrivateNetwork { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + AllowPrivateNetworkInner::Yes => f.debug_tuple("Yes").finish(), + AllowPrivateNetworkInner::No => f.debug_tuple("No").finish(), + AllowPrivateNetworkInner::Predicate(_) => f.debug_tuple("Predicate").finish(), + } + } +} + +#[derive(Clone)] +enum AllowPrivateNetworkInner { + Yes, + No, + Predicate( + Arc Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>, + ), +} + +impl Default for AllowPrivateNetworkInner { + fn default() -> Self { + Self::No + } +} + +#[cfg(test)] +mod tests { + use super::AllowPrivateNetwork; + use crate::cors::CorsLayer; + + use http::{header::ORIGIN, request::Parts, HeaderName, HeaderValue, Request, Response}; + use hyper::Body; + use tower::{BoxError, ServiceBuilder, ServiceExt}; + use tower_service::Service; + + const REQUEST_PRIVATE_NETWORK: HeaderName = + HeaderName::from_static("access-control-request-private-network"); + + const ALLOW_PRIVATE_NETWORK: HeaderName = + HeaderName::from_static("access-control-allow-private-network"); + + const TRUE: HeaderValue = HeaderValue::from_static("true"); + + #[tokio::test] + async fn cors_private_network_header_is_added_correctly() { + let mut service = ServiceBuilder::new() + .layer(CorsLayer::new().allow_private_network(true)) + .service_fn(echo); + + let req = Request::builder() + .header(REQUEST_PRIVATE_NETWORK, TRUE) + .body(Body::empty()) + .unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + + assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE); + + let req = Request::builder().body(Body::empty()).unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + + assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none()); + } + + #[tokio::test] + async fn cors_private_network_header_is_added_correctly_with_predicate() { + let allow_private_network = + AllowPrivateNetwork::predicate(|origin: &HeaderValue, parts: &Parts| { + parts.uri.path() == "/allow-private" && origin == "localhost" + }); + let mut service = ServiceBuilder::new() + .layer(CorsLayer::new().allow_private_network(allow_private_network)) + .service_fn(echo); + + let req = Request::builder() + .header(ORIGIN, "localhost") + .header(REQUEST_PRIVATE_NETWORK, TRUE) + .uri("/allow-private") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE); + + let req = Request::builder() + .header(ORIGIN, "localhost") + .header(REQUEST_PRIVATE_NETWORK, TRUE) + .uri("/other") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(req).await.unwrap(); + + assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none()); + + let req = Request::builder() + .header(ORIGIN, "not-localhost") + .header(REQUEST_PRIVATE_NETWORK, TRUE) + .uri("/allow-private") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(req).await.unwrap(); + + assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none()); + } + + async fn echo(req: Request) -> Result, BoxError> { + Ok(Response::new(req.into_body())) + } +} diff --git a/tower-http/src/cors/mod.rs b/tower-http/src/cors/mod.rs index 9a952a2c..42c355b5 100644 --- a/tower-http/src/cors/mod.rs +++ b/tower-http/src/cors/mod.rs @@ -69,13 +69,15 @@ mod allow_credentials; mod allow_headers; mod allow_methods; mod allow_origin; +mod allow_private_network; mod expose_headers; mod max_age; mod vary; pub use self::{ allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods, - allow_origin::AllowOrigin, expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary, + allow_origin::AllowOrigin, allow_private_network::AllowPrivateNetwork, + expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary, }; /// Layer that applies the [`Cors`] middleware which adds headers for [CORS][mdn]. @@ -90,6 +92,7 @@ pub struct CorsLayer { allow_headers: AllowHeaders, allow_methods: AllowMethods, allow_origin: AllowOrigin, + allow_private_network: AllowPrivateNetwork, expose_headers: ExposeHeaders, max_age: MaxAge, vary: Vary, @@ -112,6 +115,7 @@ impl CorsLayer { allow_headers: Default::default(), allow_methods: Default::default(), allow_origin: Default::default(), + allow_private_network: Default::default(), expose_headers: Default::default(), max_age: Default::default(), vary: Default::default(), @@ -360,6 +364,23 @@ impl CorsLayer { self } + /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header. + /// + /// ``` + /// use tower_http::cors::CorsLayer; + /// + /// let layer = CorsLayer::new().allow_private_network(true); + /// ``` + /// + /// [wicg]: https://wicg.github.io/private-network-access/ + pub fn allow_private_network(mut self, allow_private_network: T) -> Self + where + T: Into, + { + self.allow_private_network = allow_private_network.into(); + self + } + /// Set the value(s) of the [`Vary`][mdn] header. /// /// In contrast to the other headers, this one has a non-empty default of @@ -554,6 +575,18 @@ impl Cors { self.map_layer(|layer| layer.expose_headers(headers)) } + /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header. + /// + /// See [`CorsLayer::allow_private_network`] for more details. + /// + /// [wicg]: https://wicg.github.io/private-network-access/ + pub fn allow_private_network(self, allow_private_network: T) -> Self + where + T: Into, + { + self.map_layer(|layer| layer.allow_private_network(allow_private_network)) + } + fn map_layer(mut self, f: F) -> Self where F: FnOnce(CorsLayer) -> CorsLayer, @@ -588,6 +621,7 @@ where headers.extend(self.layer.allow_origin.to_header(origin, &parts)); headers.extend(self.layer.allow_credentials.to_header(origin, &parts)); + headers.extend(self.layer.allow_private_network.to_header(origin, &parts)); let mut vary_headers = self.layer.vary.values(); if let Some(first) = vary_headers.next() {