Skip to content

Commit

Permalink
Updating peer certificate logic to work with modern rustls
Browse files Browse the repository at this point in the history
  • Loading branch information
joelweinberger committed Oct 15, 2024
1 parent a2defac commit 740f17a
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 41 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
Cargo.lock
.idea/
warp.iml

*.swp
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ percent-encoding = "2.1"
pin-project = "1.0"
tokio-rustls = { version = "0.25", optional = true }
rustls-pemfile = { version = "2.0", optional = true }
rustls-pki-types = { version = "1.9.0", optional = true }

[dev-dependencies]
pretty_env_logger = "0.5"
Expand All @@ -56,7 +57,7 @@ listenfd = "1.0"
default = ["multipart", "websocket"]
multipart = ["multer"]
websocket = ["tokio-tungstenite"]
tls = ["tokio-rustls", "rustls-pemfile"]
tls = ["tokio-rustls", "rustls-pemfile", "rustls-pki-types"]

# Enable compression-related filters
compression = ["compression-brotli", "compression-gzip"]
Expand Down
47 changes: 19 additions & 28 deletions src/filters/mtls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
use std::convert::Infallible;

use tokio_rustls::rustls::Certificate;
use rustls_pki_types::CertificateDer;

use crate::{
filter::{filter_fn_one, Filter},
route::Route,
};

/// Certificates is a iterable container of Certificates.
pub type Certificates = Vec<CertificateDer<'static>>;

/// Creates a `Filter` to get the peer certificates for the TLS connection.
///
/// If the underlying transport doesn't have peer certificates, this will yield
Expand All @@ -27,35 +30,23 @@ use crate::{
/// ```
pub fn peer_certificates(
) -> impl Filter<Extract = (Option<Certificates>,), Error = Infallible> + Copy {
filter_fn_one(|route| futures_util::future::ok(Certificates::from_route(route)))
filter_fn_one(|route| futures_util::future::ok(from_route(route)))
}

/// Certificates is a iterable container of Certificates.
#[derive(Debug)]
pub struct Certificates(Vec<Certificate>);

impl Certificates {
fn from_route(route: &Route) -> Option<Certificates> {
route
.peer_certificates()
.read()
.unwrap()
.as_ref()
.map(|certs| Self(certs.to_vec()))
}
/// Testing
pub fn peer_certs_into_owned(certs: &Vec<CertificateDer<'_>>) -> Vec<CertificateDer<'static>> {
certs
.to_vec()
.iter()
.map(|cert| cert.clone().into_owned())
.collect()
}

impl AsRef<[Certificate]> for Certificates {
fn as_ref(&self) -> &[Certificate] {
self.0.as_ref()
}
}

impl<'a> IntoIterator for &'a Certificates {
type Item = &'a Certificate;
type IntoIter = std::slice::Iter<'a, Certificate>;

fn into_iter(self) -> Self::IntoIter {
self.0.iter()
}
fn from_route(route: &Route) -> Option<Certificates> {
route
.peer_certificates()
.read()
.unwrap()
.as_ref()
.map(peer_certs_into_owned)
}
9 changes: 6 additions & 3 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,14 @@ use crate::filters::ws::Message;
use crate::reject::IsReject;
use crate::reply::Reply;
use crate::route::{self, Route};
use crate::Request;
use crate::transport::PeerInfo;
use crate::Request;
#[cfg(feature = "websocket")]
use crate::{Sink, Stream};

#[cfg(feature = "tls")]
use crate::filters::mtls::Certificates;

use self::inner::OneOrTuple;

/// Starts a new test `RequestBuilder`.
Expand Down Expand Up @@ -256,10 +259,10 @@ impl RequestBuilder {
/// # Example
/// ```
/// let req = warp::test::request()
/// .peer_certificates([tokio_rustls::rustls::Certificate(b"FAKE CERT".to_vec())]);
/// .peer_certificates([rustls_pki_types::CertificateDer::from_slice(b"FAKE CERT")]);
/// ```
#[cfg(feature = "tls")]
pub fn peer_certificates(self, certs: impl Into<Vec<tokio_rustls::rustls::Certificate>>) -> Self {
pub fn peer_certificates(self, certs: impl Into<Certificates>) -> Self {
*self.peer_info.peer_certificates.write().unwrap() = Some(certs.into());
self
}
Expand Down
6 changes: 4 additions & 2 deletions src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use hyper::server::conn::{AddrIncoming, AddrStream};
use tokio_rustls::rustls::server::WebPkiClientVerifier;
use tokio_rustls::rustls::{Error as TlsError, RootCertStore, ServerConfig};

use crate::filters::mtls::peer_certs_into_owned;
use crate::transport::{PeerCertificates, Transport};

/// Represents errors that can occur building the TlsConfig
Expand Down Expand Up @@ -327,8 +328,9 @@ impl AsyncRead for TlsStream {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let (_, conn) = stream.get_ref();
*pin.peer_certs.write().unwrap() =
conn.peer_certificates().map(|certs| certs.to_vec());
*pin.peer_certs.write().unwrap() = conn
.peer_certificates()
.map(|certs| peer_certs_into_owned(&certs.to_vec()));

let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
Expand Down
5 changes: 4 additions & 1 deletion src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ use hyper::server::conn::AddrStream;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

#[cfg(feature = "tls")]
pub(crate) type PeerCertificates = std::sync::Arc<std::sync::RwLock<Option<Vec<tokio_rustls::rustls::Certificate>>>>;
use crate::filters::mtls::Certificates;

#[cfg(feature = "tls")]
pub(crate) type PeerCertificates = std::sync::Arc<std::sync::RwLock<Option<Certificates>>>;
#[cfg(not(feature = "tls"))]
pub(crate) type PeerCertificates = ();

Expand Down
10 changes: 4 additions & 6 deletions tests/mtls.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![deny(warnings)]
#![cfg(feature = "tls")]

use tokio_rustls::rustls::Certificate;
use rustls_pki_types::CertificateDer;

#[tokio::test]
async fn peer_certificates_missing() {
Expand All @@ -16,11 +16,9 @@ async fn peer_certificates_missing() {
async fn peer_certificates_present() {
let extract_peer_certs = warp::mtls::peer_certificates();

let cert = Certificate(b"TEST CERT".to_vec());
let cert = CertificateDer::<'_>::from_slice(b"TEST CERT");

let req = warp::test::request().peer_certificates([cert.clone()]);
let resp = req.filter(&extract_peer_certs).await.unwrap();
assert_eq!(
resp.unwrap().as_ref(),
&[cert],
)
assert_eq!(resp.unwrap(), &[cert],)
}

0 comments on commit 740f17a

Please sign in to comment.