Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Acceptor improvements #219

Merged
merged 13 commits into from
Aug 14, 2023
9 changes: 7 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@ name: rustls
permissions:
contents: read

on: [push, pull_request]
on:
push:
pull_request:
merge_group:
schedule:
- cron: '23 6 * * 5'

jobs:
build:
Expand Down Expand Up @@ -78,7 +83,7 @@ jobs:
- name: Install rust toolchain
uses: dtolnay/rust-toolchain@master
with:
toolchain: "1.60"
toolchain: "1.63"
cpu marked this conversation as resolved.
Show resolved Hide resolved

- name: Check MSRV
run: cargo check --lib --all-features
Expand Down
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "hyper-rustls"
version = "0.24.1"
edition = "2021"
rust-version = "1.60"
rust-version = "1.63"
license = "Apache-2.0 OR ISC OR MIT"
readme = "README.md"
description = "Rustls+hyper integration for pure rust HTTPS"
Expand All @@ -15,7 +15,7 @@ http = "0.2"
hyper = { version = "0.14", default-features = false, features = ["client"] }
log = { version = "0.4.4", optional = true }
rustls-native-certs = { version = "0.6", optional = true }
rustls = { version = "0.21.0", default-features = false }
rustls = { version = "0.21.6", default-features = false }
tokio = "1.0"
tokio-rustls = { version = "0.24.0", default-features = false }
webpki-roots = { version = "0.25", optional = true }
Expand Down
176 changes: 101 additions & 75 deletions src/acceptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,60 @@ use hyper::server::{
accept::Accept,
conn::{AddrIncoming, AddrStream},
};
use rustls::ServerConfig;
use rustls::{ServerConfig, ServerConnection};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

mod builder;
pub use builder::AcceptorBuilder;
use builder::WantsTlsConfig;

enum State {
Handshaking(tokio_rustls::Accept<AddrStream>),
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
/// A TLS acceptor that can be used with hyper servers.
pub struct TlsAcceptor {
config: Arc<ServerConfig>,
incoming: AddrIncoming,
}

/// An Acceptor for the `https` scheme.
impl TlsAcceptor {
/// Provides a builder for a `TlsAcceptor`.
pub fn builder() -> AcceptorBuilder<WantsTlsConfig> {
AcceptorBuilder::new()
}

/// Creates a new `TlsAcceptor` from a `ServerConfig` and an `AddrIncoming`.
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> Self {
Self { config, incoming }
}
}

impl Accept for TlsAcceptor {
type Conn = TlsStream;
type Error = io::Error;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
Poll::Ready(match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Some(Ok(sock)) => Some(Ok(TlsStream::new(sock, pin.config.clone()))),
Some(Err(e)) => Some(Err(e)),
None => None,
})
}
}

impl<C, I> From<(C, I)> for TlsAcceptor
where
C: Into<Arc<ServerConfig>>,
I: Into<AddrIncoming>,
{
fn from((config, incoming): (C, I)) -> Self {
Self::new(config.into(), incoming.into())
}
}

/// A TLS stream constructed by a [`TlsAcceptor`].
// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite by handshaking with tokio_rustls::Accept first
Expand All @@ -29,12 +71,32 @@ pub struct TlsStream {
}

impl TlsStream {
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> Self {
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
TlsStream {
Self {
state: State::Handshaking(accept),
}
}

/// Returns a reference to the underlying IO stream.
///
/// This should always return `Some`, except if an error has already been yielded.
pub fn io(&self) -> Option<&AddrStream> {
match &self.state {
State::Handshaking(accept) => accept.get_ref(),
State::Streaming(stream) => Some(stream.get_ref().0),
}
}

/// Returns a reference to the underlying [`rustls::ServerConnection'].
///
/// This will start yielding `Some` only after the handshake has completed.
pub fn connection(&self) -> Option<&ServerConnection> {
match &self.state {
State::Handshaking(_) => None,
State::Streaming(stream) => Some(stream.get_ref().1),
}
}
}

impl AsyncRead for TlsStream {
Expand All @@ -44,17 +106,19 @@ impl AsyncRead for TlsStream {
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
let accept = match &mut pin.state {
State::Handshaking(accept) => accept,
State::Streaming(stream) => return Pin::new(stream).poll_read(cx, buf),
};

let mut stream = match ready!(Pin::new(accept).poll(cx)) {
Ok(stream) => stream,
Err(err) => return Poll::Ready(Err(err)),
};

let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
}
}

Expand All @@ -65,75 +129,37 @@ impl AsyncWrite for TlsStream {
buf: &[u8],
) -> Poll<io::Result<usize>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
}
let accept = match &mut pin.state {
State::Handshaking(accept) => accept,
State::Streaming(stream) => return Pin::new(stream).poll_write(cx, buf),
};

let mut stream = match ready!(Pin::new(accept).poll(cx)) {
Ok(stream) => stream,
Err(err) => return Poll::Ready(Err(err)),
};

let result = Pin::new(&mut stream).poll_write(cx, buf);
pin.state = State::Streaming(stream);
result
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
match &mut self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
State::Streaming(stream) => Pin::new(stream).poll_flush(cx),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
match &mut self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
State::Streaming(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}

/// A TLS acceptor that can be used with hyper servers.
pub struct TlsAcceptor {
config: Arc<ServerConfig>,
incoming: AddrIncoming,
}

/// An Acceptor for the `https` scheme.
impl TlsAcceptor {
/// Provides a builder for a `TlsAcceptor`.
pub fn builder() -> AcceptorBuilder<WantsTlsConfig> {
AcceptorBuilder::new()
}
/// Creates a new `TlsAcceptor` from a `ServerConfig` and an `AddrIncoming`.
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
TlsAcceptor { config, incoming }
}
}

impl<C, I> From<(C, I)> for TlsAcceptor
where
C: Into<Arc<ServerConfig>>,
I: Into<AddrIncoming>,
{
fn from((config, incoming): (C, I)) -> TlsAcceptor {
TlsAcceptor::new(config.into(), incoming.into())
}
}

impl Accept for TlsAcceptor {
type Conn = TlsStream;
type Error = io::Error;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
enum State {
Handshaking(tokio_rustls::Accept<AddrStream>),
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}
2 changes: 1 addition & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl ConfigBuilderExt for ConfigBuilder<ClientConfig, WantsVerifier> {
#[cfg_attr(docsrs, doc(cfg(feature = "webpki-roots")))]
fn with_webpki_roots(self) -> ConfigBuilder<ClientConfig, WantsTransparencyPolicyOrClientCert> {
let mut roots = rustls::RootCertStore::empty();
roots.add_server_trust_anchors(
roots.add_trust_anchors(
webpki_roots::TLS_SERVER_ROOTS
.iter()
.map(|ta| {
Expand Down
4 changes: 2 additions & 2 deletions src/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use tokio_rustls::TlsConnector;

use crate::stream::MaybeHttpsStream;

pub mod builder;
pub(crate) mod builder;

type BoxError = Box<dyn std::error::Error + Send + Sync>;

Expand Down Expand Up @@ -45,7 +45,7 @@ where
C: Into<Arc<rustls::ClientConfig>>,
{
fn from((http, cfg): (H, C)) -> Self {
HttpsConnector {
Self {
force_https: false,
http,
tls_config: cfg.into(),
Expand Down
7 changes: 4 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,19 @@
//! # fn main() {}
//! ```

#![warn(missing_docs)]
#![warn(missing_docs, unreachable_pub, clippy::use_self)]
#![cfg_attr(docsrs, feature(doc_cfg))]

#[cfg(feature = "acceptor")]
mod acceptor;
/// TLS acceptor implementing hyper's `Accept` trait.
pub mod acceptor;
mod config;
mod connector;
mod stream;

#[cfg(feature = "logging")]
mod log {
pub use log::{debug, trace};
pub(crate) use log::{debug, trace};
}

#[cfg(not(feature = "logging"))]
Expand Down
28 changes: 14 additions & 14 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ pub enum MaybeHttpsStream<T> {
impl<T: AsyncRead + AsyncWrite + Connection + Unpin> Connection for MaybeHttpsStream<T> {
fn connected(&self) -> Connected {
match self {
MaybeHttpsStream::Http(s) => s.connected(),
MaybeHttpsStream::Https(s) => {
Self::Http(s) => s.connected(),
Self::Https(s) => {
let (tcp, tls) = s.get_ref();
if tls.alpn_protocol() == Some(b"h2") {
tcp.connected().negotiated_h2()
Expand All @@ -37,21 +37,21 @@ impl<T: AsyncRead + AsyncWrite + Connection + Unpin> Connection for MaybeHttpsSt
impl<T: fmt::Debug> fmt::Debug for MaybeHttpsStream<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
MaybeHttpsStream::Http(..) => f.pad("Http(..)"),
MaybeHttpsStream::Https(..) => f.pad("Https(..)"),
Self::Http(..) => f.pad("Http(..)"),
Self::Https(..) => f.pad("Https(..)"),
}
}
}

impl<T> From<T> for MaybeHttpsStream<T> {
fn from(inner: T) -> Self {
MaybeHttpsStream::Http(inner)
Self::Http(inner)
}
}

impl<T> From<TlsStream<T>> for MaybeHttpsStream<T> {
fn from(inner: TlsStream<T>) -> Self {
MaybeHttpsStream::Https(inner)
Self::Https(inner)
}
}

Expand All @@ -63,8 +63,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeHttpsStream<T> {
buf: &mut ReadBuf<'_>,
) -> Poll<Result<(), io::Error>> {
match Pin::get_mut(self) {
MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(cx, buf),
MaybeHttpsStream::Https(s) => Pin::new(s).poll_read(cx, buf),
Self::Http(s) => Pin::new(s).poll_read(cx, buf),
Self::Https(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
Expand All @@ -77,24 +77,24 @@ impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for MaybeHttpsStream<T> {
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
match Pin::get_mut(self) {
MaybeHttpsStream::Http(s) => Pin::new(s).poll_write(cx, buf),
MaybeHttpsStream::Https(s) => Pin::new(s).poll_write(cx, buf),
Self::Http(s) => Pin::new(s).poll_write(cx, buf),
Self::Https(s) => Pin::new(s).poll_write(cx, buf),
}
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match Pin::get_mut(self) {
MaybeHttpsStream::Http(s) => Pin::new(s).poll_flush(cx),
MaybeHttpsStream::Https(s) => Pin::new(s).poll_flush(cx),
Self::Http(s) => Pin::new(s).poll_flush(cx),
Self::Https(s) => Pin::new(s).poll_flush(cx),
}
}

#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match Pin::get_mut(self) {
MaybeHttpsStream::Http(s) => Pin::new(s).poll_shutdown(cx),
MaybeHttpsStream::Https(s) => Pin::new(s).poll_shutdown(cx),
Self::Http(s) => Pin::new(s).poll_shutdown(cx),
Self::Https(s) => Pin::new(s).poll_shutdown(cx),
}
}
}
Loading