Skip to content

Commit

Permalink
Add mtls::peer_certificates filter
Browse files Browse the repository at this point in the history
  • Loading branch information
lann authored and joelweinberger committed Oct 14, 2024
1 parent ce8114b commit a2defac
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 21 deletions.
10 changes: 5 additions & 5 deletions src/filter/service.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::convert::Infallible;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};

Expand All @@ -11,6 +10,7 @@ use pin_project::pin_project;
use crate::reject::IsReject;
use crate::reply::{Reply, Response};
use crate::route::{self, Route};
use crate::transport::PeerInfo;
use crate::{Filter, Request};

/// Convert a `Filter` into a `Service`.
Expand Down Expand Up @@ -70,14 +70,14 @@ where
<F::Future as TryFuture>::Error: IsReject,
{
#[inline]
pub(crate) fn call_with_addr(
pub(crate) fn call_with_peer_info(
&self,
req: Request,
remote_addr: Option<SocketAddr>,
peer_info: PeerInfo,
) -> FilteredFuture<F::Future> {
debug_assert!(!route::is_set(), "nested route::set calls");

let route = Route::new(req, remote_addr);
let route = Route::new(req, peer_info);
let fut = route::set(&route, || self.filter.filter(super::Internal));
FilteredFuture { future: fut, route }
}
Expand All @@ -99,7 +99,7 @@ where

#[inline]
fn call(&mut self, req: Request) -> Self::Future {
self.call_with_addr(req, None)
self.call_with_peer_info(req, Default::default())
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/filters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub mod path;
pub mod query;
pub mod reply;
pub mod sse;
#[cfg(feature = "tls")]
pub mod mtls;
pub mod trace;
#[cfg(feature = "websocket")]
pub mod ws;
Expand Down
61 changes: 61 additions & 0 deletions src/filters/mtls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//! Mutual (client) TLS filters.
use std::convert::Infallible;

use tokio_rustls::rustls::Certificate;

use crate::{
filter::{filter_fn_one, Filter},
route::Route,
};

/// Creates a `Filter` to get the peer certificates for the TLS connection.
///
/// If the underlying transport doesn't have peer certificates, this will yield
/// `None`.
///
/// # Example
///
/// ```
/// use warp::mtls::Certificates;
/// use warp::Filter;
///
/// let route = warp::mtls::peer_certificates()
/// .map(|certs: Option<Certificates>| {
/// println!("peer certificates = {:?}", certs.as_ref());
/// });
/// ```
pub fn peer_certificates(
) -> impl Filter<Extract = (Option<Certificates>,), Error = Infallible> + Copy {
filter_fn_one(|route| futures_util::future::ok(Certificates::from_route(route)))
}

/// Certificates is a iterable container of Certificates.
#[derive(Debug)]
pub struct Certificates(Vec<Certificate>);

impl Certificates {
fn from_route(route: &Route) -> Option<Certificates> {
route
.peer_certificates()
.read()
.unwrap()
.as_ref()
.map(|certs| Self(certs.to_vec()))
}
}

impl AsRef<[Certificate]> for Certificates {
fn as_ref(&self) -> &[Certificate] {
self.0.as_ref()
}
}

impl<'a> IntoIterator for &'a Certificates {
type Item = &'a Certificate;
type IntoIter = std::slice::Iter<'a, Certificate>;

fn into_iter(self) -> Self::IntoIter {
self.0.iter()
}
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ pub use self::filters::compression;
#[cfg(feature = "multipart")]
#[doc(hidden)]
pub use self::filters::multipart;
#[cfg(feature = "tls")]
#[doc(hidden)]
pub use self::filters::mtls;
#[cfg(feature = "websocket")]
#[doc(hidden)]
pub use self::filters::ws;
Expand Down
14 changes: 10 additions & 4 deletions src/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::net::SocketAddr;
use hyper::Body;

use crate::Request;
use crate::transport::PeerInfo;

scoped_thread_local!(static ROUTE: RefCell<Route>);

Expand All @@ -30,7 +31,7 @@ where
#[derive(Debug)]
pub(crate) struct Route {
body: BodyState,
remote_addr: Option<SocketAddr>,
peer_info: PeerInfo,
req: Request,
segments_index: usize,
}
Expand All @@ -42,7 +43,7 @@ enum BodyState {
}

impl Route {
pub(crate) fn new(req: Request, remote_addr: Option<SocketAddr>) -> RefCell<Route> {
pub(crate) fn new(req: Request, peer_info: PeerInfo) -> RefCell<Route> {
let segments_index = if req.uri().path().starts_with('/') {
// Skip the beginning slash.
1
Expand All @@ -52,7 +53,7 @@ impl Route {

RefCell::new(Route {
body: BodyState::Ready,
remote_addr,
peer_info,
req,
segments_index,
})
Expand Down Expand Up @@ -124,7 +125,12 @@ impl Route {
}

pub(crate) fn remote_addr(&self) -> Option<SocketAddr> {
self.remote_addr
self.peer_info.remote_addr
}

#[cfg(feature = "tls")]
pub(crate) fn peer_certificates(&self) -> crate::transport::PeerCertificates {
self.peer_info.peer_certificates.clone()
}

pub(crate) fn take_body(&mut self) -> Option<Body> {
Expand Down
9 changes: 7 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,14 @@ macro_rules! into_service {
let inner = crate::service($into);
make_service_fn(move |transport| {
let inner = inner.clone();
let remote_addr = Transport::remote_addr(transport);

let peer_info = crate::transport::PeerInfo {
remote_addr: Transport::remote_addr(transport),
peer_certificates: Transport::peer_certificates(transport),
};

future::ok::<_, Infallible>(service_fn(move |req| {
inner.call_with_addr(req, remote_addr)
inner.call_with_peer_info(req, peer_info.clone())
}))
})
}};
Expand Down
30 changes: 21 additions & 9 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,15 @@ use crate::reject::IsReject;
use crate::reply::Reply;
use crate::route::{self, Route};
use crate::Request;
use crate::transport::PeerInfo;
#[cfg(feature = "websocket")]
use crate::{Sink, Stream};

use self::inner::OneOrTuple;

/// Starts a new test `RequestBuilder`.
pub fn request() -> RequestBuilder {
RequestBuilder {
remote_addr: None,
req: Request::default(),
}
Default::default()
}

/// Starts a new test `WsBuilder`.
Expand All @@ -137,9 +135,9 @@ pub fn ws() -> WsBuilder {
///
/// See [module documentation](crate::test) for an overview.
#[must_use = "RequestBuilder does nothing on its own"]
#[derive(Debug)]
#[derive(Debug, Default)]
pub struct RequestBuilder {
remote_addr: Option<SocketAddr>,
peer_info: PeerInfo,
req: Request,
}

Expand Down Expand Up @@ -248,7 +246,21 @@ impl RequestBuilder {
/// .remote_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
/// ```
pub fn remote_addr(mut self, addr: SocketAddr) -> Self {
self.remote_addr = Some(addr);
self.peer_info.remote_addr = Some(addr);
self
}

/// Set the peer certificates of this request.
/// Default is no peer certificates.
///
/// # Example
/// ```
/// let req = warp::test::request()
/// .peer_certificates([tokio_rustls::rustls::Certificate(b"FAKE CERT".to_vec())]);
/// ```
#[cfg(feature = "tls")]
pub fn peer_certificates(self, certs: impl Into<Vec<tokio_rustls::rustls::Certificate>>) -> Self {
*self.peer_info.peer_certificates.write().unwrap() = Some(certs.into());
self
}

Expand Down Expand Up @@ -375,7 +387,7 @@ impl RequestBuilder {
// TODO: de-duplicate this and apply_filter()
assert!(!route::is_set(), "nested test filter calls");

let route = Route::new(self.req, self.remote_addr);
let route = Route::new(self.req, self.peer_info);
let mut fut = Box::pin(
route::set(&route, move || f.filter(crate::filter::Internal)).then(|result| {
let res = match result {
Expand Down Expand Up @@ -404,7 +416,7 @@ impl RequestBuilder {
{
assert!(!route::is_set(), "nested test filter calls");

let route = Route::new(self.req, self.remote_addr);
let route = Route::new(self.req, self.peer_info);
let mut fut = Box::pin(route::set(&route, move || {
f.filter(crate::filter::Internal)
}));
Expand Down
12 changes: 11 additions & 1 deletion src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use hyper::server::conn::{AddrIncoming, AddrStream};
use tokio_rustls::rustls::server::WebPkiClientVerifier;
use tokio_rustls::rustls::{Error as TlsError, RootCertStore, ServerConfig};

use crate::transport::Transport;
use crate::transport::{PeerCertificates, Transport};

/// Represents errors that can occur building the TlsConfig
#[derive(Debug)]
Expand Down Expand Up @@ -284,6 +284,10 @@ impl Transport for TlsStream {
fn remote_addr(&self) -> Option<SocketAddr> {
Some(self.remote_addr)
}

fn peer_certificates(&self) -> PeerCertificates {
self.peer_certs.clone()
}
}

enum State {
Expand All @@ -297,6 +301,7 @@ enum State {
pub(crate) struct TlsStream {
state: State,
remote_addr: SocketAddr,
peer_certs: PeerCertificates,
}

impl TlsStream {
Expand All @@ -306,6 +311,7 @@ impl TlsStream {
TlsStream {
state: State::Handshaking(accept),
remote_addr,
peer_certs: Default::default(),
}
}
}
Expand All @@ -320,6 +326,10 @@ impl AsyncRead for TlsStream {
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let (_, conn) = stream.get_ref();
*pin.peer_certs.write().unwrap() =
conn.peer_certificates().map(|certs| certs.to_vec());

let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
Expand Down
16 changes: 16 additions & 0 deletions src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,24 @@ use std::task::{Context, Poll};
use hyper::server::conn::AddrStream;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

#[cfg(feature = "tls")]
pub(crate) type PeerCertificates = std::sync::Arc<std::sync::RwLock<Option<Vec<tokio_rustls::rustls::Certificate>>>>;
#[cfg(not(feature = "tls"))]
pub(crate) type PeerCertificates = ();

pub trait Transport: AsyncRead + AsyncWrite {
fn remote_addr(&self) -> Option<SocketAddr>;

fn peer_certificates(&self) -> PeerCertificates {
Default::default()
}
}

#[derive(Clone, Debug, Default)]
pub(crate) struct PeerInfo {
pub remote_addr: Option<SocketAddr>,
#[allow(dead_code)]
pub peer_certificates: PeerCertificates,
}

impl Transport for AddrStream {
Expand Down
26 changes: 26 additions & 0 deletions tests/mtls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#![deny(warnings)]
#![cfg(feature = "tls")]

use tokio_rustls::rustls::Certificate;

#[tokio::test]
async fn peer_certificates_missing() {
let extract_peer_certs = warp::mtls::peer_certificates();

let req = warp::test::request();
let resp = req.filter(&extract_peer_certs).await.unwrap();
assert!(resp.is_none())
}

#[tokio::test]
async fn peer_certificates_present() {
let extract_peer_certs = warp::mtls::peer_certificates();

let cert = Certificate(b"TEST CERT".to_vec());
let req = warp::test::request().peer_certificates([cert.clone()]);
let resp = req.filter(&extract_peer_certs).await.unwrap();
assert_eq!(
resp.unwrap().as_ref(),
&[cert],
)
}

0 comments on commit a2defac

Please sign in to comment.