Skip to content

Commit

Permalink
cors: Add support for private network preflights
Browse files Browse the repository at this point in the history
Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
  • Loading branch information
Ptrskay3 and jplatte authored Jul 19, 2023
1 parent 5afd958 commit 77b34f8
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tower-http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ bitflags = "2.0.2"
bytes = "1"
futures-core = "0.3"
futures-util = { version = "0.3.14", default_features = false, features = [] }
http = "0.2.2"
http = "0.2.7"
http-body = "0.4.5"
pin-project-lite = "0.2.7"
tower-layer = "0.3"
Expand Down
2 changes: 2 additions & 0 deletions tower-http/src/cors/allow_credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ impl AllowCredentials {

/// Allow credentials for some requests, based on a given predicate
///
/// The first argument to the predicate is the request origin.
///
/// See [`CorsLayer::allow_credentials`] for more details.
///
/// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials
Expand Down
196 changes: 196 additions & 0 deletions tower-http/src/cors/allow_private_network.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
use std::{fmt, sync::Arc};

use http::{
header::{HeaderName, HeaderValue},
request::Parts as RequestParts,
};

/// Holds configuration for how to set the [`Access-Control-Allow-Private-Network`][wicg] header.
///
/// See [`CorsLayer::allow_private_network`] for more details.
///
/// [wicg]: https://wicg.github.io/private-network-access/
/// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network
#[derive(Clone, Default)]
#[must_use]
pub struct AllowPrivateNetwork(AllowPrivateNetworkInner);

impl AllowPrivateNetwork {
/// Allow requests via a more private network than the one used to access the origin
///
/// See [`CorsLayer::allow_private_network`] for more details.
///
/// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network
pub fn yes() -> Self {
Self(AllowPrivateNetworkInner::Yes)
}

/// Allow requests via private network for some requests, based on a given predicate
///
/// The first argument to the predicate is the request origin.
///
/// See [`CorsLayer::allow_private_network`] for more details.
///
/// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network
pub fn predicate<F>(f: F) -> Self
where
F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
{
Self(AllowPrivateNetworkInner::Predicate(Arc::new(f)))
}

pub(super) fn to_header(
&self,
origin: Option<&HeaderValue>,
parts: &RequestParts,
) -> Option<(HeaderName, HeaderValue)> {
#[allow(clippy::declare_interior_mutable_const)]
const REQUEST_PRIVATE_NETWORK: HeaderName =
HeaderName::from_static("access-control-request-private-network");

#[allow(clippy::declare_interior_mutable_const)]
const ALLOW_PRIVATE_NETWORK: HeaderName =
HeaderName::from_static("access-control-allow-private-network");

const TRUE: HeaderValue = HeaderValue::from_static("true");

// Cheapest fallback: allow_private_network hasn't been set
if let AllowPrivateNetworkInner::No = &self.0 {
return None;
}

// Access-Control-Allow-Private-Network is only relevant if the request
// has the Access-Control-Request-Private-Network header set, else skip
if parts.headers.get(REQUEST_PRIVATE_NETWORK) != Some(&TRUE) {
return None;
}

let allow_private_network = match &self.0 {
AllowPrivateNetworkInner::Yes => true,
AllowPrivateNetworkInner::No => false, // unreachable, but not harmful
AllowPrivateNetworkInner::Predicate(c) => c(origin?, parts),
};

allow_private_network.then(|| (ALLOW_PRIVATE_NETWORK, TRUE))
}
}

impl From<bool> for AllowPrivateNetwork {
fn from(v: bool) -> Self {
match v {
true => Self(AllowPrivateNetworkInner::Yes),
false => Self(AllowPrivateNetworkInner::No),
}
}
}

impl fmt::Debug for AllowPrivateNetwork {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.0 {
AllowPrivateNetworkInner::Yes => f.debug_tuple("Yes").finish(),
AllowPrivateNetworkInner::No => f.debug_tuple("No").finish(),
AllowPrivateNetworkInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
}
}
}

#[derive(Clone)]
enum AllowPrivateNetworkInner {
Yes,
No,
Predicate(
Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
),
}

impl Default for AllowPrivateNetworkInner {
fn default() -> Self {
Self::No
}
}

#[cfg(test)]
mod tests {
use super::AllowPrivateNetwork;
use crate::cors::CorsLayer;

use http::{header::ORIGIN, request::Parts, HeaderName, HeaderValue, Request, Response};
use hyper::Body;
use tower::{BoxError, ServiceBuilder, ServiceExt};
use tower_service::Service;

const REQUEST_PRIVATE_NETWORK: HeaderName =
HeaderName::from_static("access-control-request-private-network");

const ALLOW_PRIVATE_NETWORK: HeaderName =
HeaderName::from_static("access-control-allow-private-network");

const TRUE: HeaderValue = HeaderValue::from_static("true");

#[tokio::test]
async fn cors_private_network_header_is_added_correctly() {
let mut service = ServiceBuilder::new()
.layer(CorsLayer::new().allow_private_network(true))
.service_fn(echo);

let req = Request::builder()
.header(REQUEST_PRIVATE_NETWORK, TRUE)
.body(Body::empty())
.unwrap();
let res = service.ready().await.unwrap().call(req).await.unwrap();

assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE);

let req = Request::builder().body(Body::empty()).unwrap();
let res = service.ready().await.unwrap().call(req).await.unwrap();

assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
}

#[tokio::test]
async fn cors_private_network_header_is_added_correctly_with_predicate() {
let allow_private_network =
AllowPrivateNetwork::predicate(|origin: &HeaderValue, parts: &Parts| {
parts.uri.path() == "/allow-private" && origin == "localhost"
});
let mut service = ServiceBuilder::new()
.layer(CorsLayer::new().allow_private_network(allow_private_network))
.service_fn(echo);

let req = Request::builder()
.header(ORIGIN, "localhost")
.header(REQUEST_PRIVATE_NETWORK, TRUE)
.uri("/allow-private")
.body(Body::empty())
.unwrap();

let res = service.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE);

let req = Request::builder()
.header(ORIGIN, "localhost")
.header(REQUEST_PRIVATE_NETWORK, TRUE)
.uri("/other")
.body(Body::empty())
.unwrap();

let res = service.ready().await.unwrap().call(req).await.unwrap();

assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());

let req = Request::builder()
.header(ORIGIN, "not-localhost")
.header(REQUEST_PRIVATE_NETWORK, TRUE)
.uri("/allow-private")
.body(Body::empty())
.unwrap();

let res = service.ready().await.unwrap().call(req).await.unwrap();

assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
}

async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
Ok(Response::new(req.into_body()))
}
}
36 changes: 35 additions & 1 deletion tower-http/src/cors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,15 @@ mod allow_credentials;
mod allow_headers;
mod allow_methods;
mod allow_origin;
mod allow_private_network;
mod expose_headers;
mod max_age;
mod vary;

pub use self::{
allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods,
allow_origin::AllowOrigin, expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary,
allow_origin::AllowOrigin, allow_private_network::AllowPrivateNetwork,
expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary,
};

/// Layer that applies the [`Cors`] middleware which adds headers for [CORS][mdn].
Expand All @@ -90,6 +92,7 @@ pub struct CorsLayer {
allow_headers: AllowHeaders,
allow_methods: AllowMethods,
allow_origin: AllowOrigin,
allow_private_network: AllowPrivateNetwork,
expose_headers: ExposeHeaders,
max_age: MaxAge,
vary: Vary,
Expand All @@ -112,6 +115,7 @@ impl CorsLayer {
allow_headers: Default::default(),
allow_methods: Default::default(),
allow_origin: Default::default(),
allow_private_network: Default::default(),
expose_headers: Default::default(),
max_age: Default::default(),
vary: Default::default(),
Expand Down Expand Up @@ -360,6 +364,23 @@ impl CorsLayer {
self
}

/// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
///
/// ```
/// use tower_http::cors::CorsLayer;
///
/// let layer = CorsLayer::new().allow_private_network(true);
/// ```
///
/// [wicg]: https://wicg.github.io/private-network-access/
pub fn allow_private_network<T>(mut self, allow_private_network: T) -> Self
where
T: Into<AllowPrivateNetwork>,
{
self.allow_private_network = allow_private_network.into();
self
}

/// Set the value(s) of the [`Vary`][mdn] header.
///
/// In contrast to the other headers, this one has a non-empty default of
Expand Down Expand Up @@ -554,6 +575,18 @@ impl<S> Cors<S> {
self.map_layer(|layer| layer.expose_headers(headers))
}

/// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
///
/// See [`CorsLayer::allow_private_network`] for more details.
///
/// [wicg]: https://wicg.github.io/private-network-access/
pub fn allow_private_network<T>(self, allow_private_network: T) -> Self
where
T: Into<AllowPrivateNetwork>,
{
self.map_layer(|layer| layer.allow_private_network(allow_private_network))
}

fn map_layer<F>(mut self, f: F) -> Self
where
F: FnOnce(CorsLayer) -> CorsLayer,
Expand Down Expand Up @@ -588,6 +621,7 @@ where

headers.extend(self.layer.allow_origin.to_header(origin, &parts));
headers.extend(self.layer.allow_credentials.to_header(origin, &parts));
headers.extend(self.layer.allow_private_network.to_header(origin, &parts));

let mut vary_headers = self.layer.vary.values();
if let Some(first) = vary_headers.next() {
Expand Down

0 comments on commit 77b34f8

Please sign in to comment.