diff --git a/examples/http_proxy.rs b/examples/http_proxy.rs index 39b2a4644d..a08bdc4a29 100644 --- a/examples/http_proxy.rs +++ b/examples/http_proxy.rs @@ -58,7 +58,7 @@ async fn proxy(client: HttpClient, req: Request) -> Result, // `on_upgrade` future. if let Some(addr) = host_addr(req.uri()) { tokio::task::spawn(async move { - match req.into_body().on_upgrade().await { + match hyper::upgrade::on(req).await { Ok(upgraded) => { if let Err(e) = tunnel(upgraded, addr).await { eprintln!("server io error: {}", e); diff --git a/examples/upgrades.rs b/examples/upgrades.rs index 0bc75444a4..38cfded386 100644 --- a/examples/upgrades.rs +++ b/examples/upgrades.rs @@ -34,7 +34,7 @@ async fn server_upgraded_io(mut upgraded: Upgraded) -> Result<()> { } /// Our server HTTP handler to initiate HTTP upgrades. -async fn server_upgrade(req: Request) -> Result> { +async fn server_upgrade(mut req: Request) -> Result> { let mut res = Response::new(Body::empty()); // Send a 400 to any request that doesn't have @@ -52,7 +52,7 @@ async fn server_upgrade(req: Request) -> Result> { // is returned below, so it's better to spawn this future instead // waiting for it to complete to then return a response. tokio::task::spawn(async move { - match req.into_body().on_upgrade().await { + match hyper::upgrade::on(&mut req).await { Ok(upgraded) => { if let Err(e) = server_upgraded_io(upgraded).await { eprintln!("server foobar io error: {}", e) @@ -97,7 +97,7 @@ async fn client_upgrade_request(addr: SocketAddr) -> Result<()> { panic!("Our server didn't upgrade: {}", res.status()); } - match res.into_body().on_upgrade().await { + match hyper::upgrade::on(res).await { Ok(upgraded) => { if let Err(e) = client_upgraded_io(upgraded).await { eprintln!("client foobar io error: {}", e) diff --git a/src/body/body.rs b/src/body/body.rs index e2e216f87d..fca66a6f93 100644 --- a/src/body/body.rs +++ b/src/body/body.rs @@ -187,13 +187,15 @@ impl Body { Body::new(Kind::Wrapped(SyncWrapper::new(Box::pin(mapped)))) } - /// Converts this `Body` into a `Future` of a pending HTTP upgrade. - /// - /// See [the `upgrade` module](crate::upgrade) for more. - pub fn on_upgrade(self) -> OnUpgrade { - self.extra - .map(|ex| ex.on_upgrade) - .unwrap_or_else(OnUpgrade::none) + // TODO: Eventually the pending upgrade should be stored in the + // `Extensions`, and all these pieces can be removed. In v0.14, we made + // the breaking changes, so now this TODO can be done without breakage. + pub(crate) fn take_upgrade(&mut self) -> OnUpgrade { + if let Some(ref mut extra) = self.extra { + std::mem::replace(&mut extra.on_upgrade, OnUpgrade::none()) + } else { + OnUpgrade::none() + } } fn new(kind: Kind) -> Body { diff --git a/src/upgrade.rs b/src/upgrade.rs index 7e51bf22ce..4f377b8c4a 100644 --- a/src/upgrade.rs +++ b/src/upgrade.rs @@ -57,18 +57,16 @@ pub struct Parts { _inner: (), } +/// Gets a pending HTTP upgrade from this message. +pub fn on(msg: T) -> OnUpgrade { + msg.on_upgrade() +} + #[cfg(feature = "http1")] pub(crate) struct Pending { tx: oneshot::Sender>, } -/// Error cause returned when an upgrade was expected but canceled -/// for whatever reason. -/// -/// This likely means the actual `Conn` future wasn't polled and upgraded. -#[derive(Debug)] -struct UpgradeExpected(()); - #[cfg(feature = "http1")] pub(crate) fn pending() -> (Pending, OnUpgrade) { let (tx, rx) = oneshot::channel(); @@ -162,9 +160,7 @@ impl Future for OnUpgrade { Some(ref mut rx) => Pin::new(rx).poll(cx).map(|res| match res { Ok(Ok(upgraded)) => Ok(upgraded), Ok(Err(err)) => Err(err), - Err(_oneshot_canceled) => { - Err(crate::Error::new_canceled().with(UpgradeExpected(()))) - } + Err(_oneshot_canceled) => Err(crate::Error::new_canceled().with(UpgradeExpected)), }), None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())), } @@ -196,9 +192,16 @@ impl Pending { // ===== impl UpgradeExpected ===== +/// Error cause returned when an upgrade was expected but canceled +/// for whatever reason. +/// +/// This likely means the actual `Conn` future wasn't polled and upgraded. +#[derive(Debug)] +struct UpgradeExpected; + impl fmt::Display for UpgradeExpected { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "upgrade expected but not completed") + f.write_str("upgrade expected but not completed") } } @@ -277,6 +280,38 @@ impl Io for ForwardsWriteBuf { } } +mod sealed { + use super::OnUpgrade; + + pub trait CanUpgrade { + fn on_upgrade(self) -> OnUpgrade; + } + + impl CanUpgrade for http::Request { + fn on_upgrade(self) -> OnUpgrade { + self.into_body().take_upgrade() + } + } + + impl CanUpgrade for &'_ mut http::Request { + fn on_upgrade(self) -> OnUpgrade { + self.body_mut().take_upgrade() + } + } + + impl CanUpgrade for http::Response { + fn on_upgrade(self) -> OnUpgrade { + self.into_body().take_upgrade() + } + } + + impl CanUpgrade for &'_ mut http::Response { + fn on_upgrade(self) -> OnUpgrade { + self.body_mut().take_upgrade() + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/tests/client.rs b/tests/client.rs index d3a91aae0d..d334d06500 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1791,7 +1791,7 @@ mod dispatch_impl { assert_eq!(res.status(), 101); let upgraded = rt - .block_on(res.into_body().on_upgrade()) + .block_on(hyper::upgrade::on(res)) .expect("on_upgrade"); let parts = upgraded.downcast::().unwrap(); diff --git a/tests/server.rs b/tests/server.rs index 5ecd21a826..72d2e459d8 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -1341,7 +1341,7 @@ async fn upgrades_new() { let (upgrades_tx, upgrades_rx) = mpsc::channel(); let svc = service_fn(move |req: Request| { - let on_upgrade = req.into_body().on_upgrade(); + let on_upgrade = hyper::upgrade::on(req); let _ = upgrades_tx.send(on_upgrade); future::ok::<_, hyper::Error>( Response::builder() @@ -1448,7 +1448,7 @@ async fn http_connect_new() { let (upgrades_tx, upgrades_rx) = mpsc::channel(); let svc = service_fn(move |req: Request| { - let on_upgrade = req.into_body().on_upgrade(); + let on_upgrade = hyper::upgrade::on(req); let _ = upgrades_tx.send(on_upgrade); future::ok::<_, hyper::Error>( Response::builder()