diff --git a/examples/http_proxy.rs b/examples/http_proxy.rs
index 39b2a4644d..a08bdc4a29 100644
--- a/examples/http_proxy.rs
+++ b/examples/http_proxy.rs
@@ -58,7 +58,7 @@ async fn proxy(client: HttpClient, req: Request
) -> Result,
// `on_upgrade` future.
if let Some(addr) = host_addr(req.uri()) {
tokio::task::spawn(async move {
- match req.into_body().on_upgrade().await {
+ match hyper::upgrade::on(req).await {
Ok(upgraded) => {
if let Err(e) = tunnel(upgraded, addr).await {
eprintln!("server io error: {}", e);
diff --git a/examples/upgrades.rs b/examples/upgrades.rs
index 0bc75444a4..38cfded386 100644
--- a/examples/upgrades.rs
+++ b/examples/upgrades.rs
@@ -34,7 +34,7 @@ async fn server_upgraded_io(mut upgraded: Upgraded) -> Result<()> {
}
/// Our server HTTP handler to initiate HTTP upgrades.
-async fn server_upgrade(req: Request) -> Result> {
+async fn server_upgrade(mut req: Request) -> Result> {
let mut res = Response::new(Body::empty());
// Send a 400 to any request that doesn't have
@@ -52,7 +52,7 @@ async fn server_upgrade(req: Request) -> Result> {
// is returned below, so it's better to spawn this future instead
// waiting for it to complete to then return a response.
tokio::task::spawn(async move {
- match req.into_body().on_upgrade().await {
+ match hyper::upgrade::on(&mut req).await {
Ok(upgraded) => {
if let Err(e) = server_upgraded_io(upgraded).await {
eprintln!("server foobar io error: {}", e)
@@ -97,7 +97,7 @@ async fn client_upgrade_request(addr: SocketAddr) -> Result<()> {
panic!("Our server didn't upgrade: {}", res.status());
}
- match res.into_body().on_upgrade().await {
+ match hyper::upgrade::on(res).await {
Ok(upgraded) => {
if let Err(e) = client_upgraded_io(upgraded).await {
eprintln!("client foobar io error: {}", e)
diff --git a/src/body/body.rs b/src/body/body.rs
index e2e216f87d..fca66a6f93 100644
--- a/src/body/body.rs
+++ b/src/body/body.rs
@@ -187,13 +187,15 @@ impl Body {
Body::new(Kind::Wrapped(SyncWrapper::new(Box::pin(mapped))))
}
- /// Converts this `Body` into a `Future` of a pending HTTP upgrade.
- ///
- /// See [the `upgrade` module](crate::upgrade) for more.
- pub fn on_upgrade(self) -> OnUpgrade {
- self.extra
- .map(|ex| ex.on_upgrade)
- .unwrap_or_else(OnUpgrade::none)
+ // TODO: Eventually the pending upgrade should be stored in the
+ // `Extensions`, and all these pieces can be removed. In v0.14, we made
+ // the breaking changes, so now this TODO can be done without breakage.
+ pub(crate) fn take_upgrade(&mut self) -> OnUpgrade {
+ if let Some(ref mut extra) = self.extra {
+ std::mem::replace(&mut extra.on_upgrade, OnUpgrade::none())
+ } else {
+ OnUpgrade::none()
+ }
}
fn new(kind: Kind) -> Body {
diff --git a/src/upgrade.rs b/src/upgrade.rs
index 7e51bf22ce..4f377b8c4a 100644
--- a/src/upgrade.rs
+++ b/src/upgrade.rs
@@ -57,18 +57,16 @@ pub struct Parts {
_inner: (),
}
+/// Gets a pending HTTP upgrade from this message.
+pub fn on(msg: T) -> OnUpgrade {
+ msg.on_upgrade()
+}
+
#[cfg(feature = "http1")]
pub(crate) struct Pending {
tx: oneshot::Sender>,
}
-/// Error cause returned when an upgrade was expected but canceled
-/// for whatever reason.
-///
-/// This likely means the actual `Conn` future wasn't polled and upgraded.
-#[derive(Debug)]
-struct UpgradeExpected(());
-
#[cfg(feature = "http1")]
pub(crate) fn pending() -> (Pending, OnUpgrade) {
let (tx, rx) = oneshot::channel();
@@ -162,9 +160,7 @@ impl Future for OnUpgrade {
Some(ref mut rx) => Pin::new(rx).poll(cx).map(|res| match res {
Ok(Ok(upgraded)) => Ok(upgraded),
Ok(Err(err)) => Err(err),
- Err(_oneshot_canceled) => {
- Err(crate::Error::new_canceled().with(UpgradeExpected(())))
- }
+ Err(_oneshot_canceled) => Err(crate::Error::new_canceled().with(UpgradeExpected)),
}),
None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())),
}
@@ -196,9 +192,16 @@ impl Pending {
// ===== impl UpgradeExpected =====
+/// Error cause returned when an upgrade was expected but canceled
+/// for whatever reason.
+///
+/// This likely means the actual `Conn` future wasn't polled and upgraded.
+#[derive(Debug)]
+struct UpgradeExpected;
+
impl fmt::Display for UpgradeExpected {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "upgrade expected but not completed")
+ f.write_str("upgrade expected but not completed")
}
}
@@ -277,6 +280,38 @@ impl Io for ForwardsWriteBuf {
}
}
+mod sealed {
+ use super::OnUpgrade;
+
+ pub trait CanUpgrade {
+ fn on_upgrade(self) -> OnUpgrade;
+ }
+
+ impl CanUpgrade for http::Request {
+ fn on_upgrade(self) -> OnUpgrade {
+ self.into_body().take_upgrade()
+ }
+ }
+
+ impl CanUpgrade for &'_ mut http::Request {
+ fn on_upgrade(self) -> OnUpgrade {
+ self.body_mut().take_upgrade()
+ }
+ }
+
+ impl CanUpgrade for http::Response {
+ fn on_upgrade(self) -> OnUpgrade {
+ self.into_body().take_upgrade()
+ }
+ }
+
+ impl CanUpgrade for &'_ mut http::Response {
+ fn on_upgrade(self) -> OnUpgrade {
+ self.body_mut().take_upgrade()
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/tests/client.rs b/tests/client.rs
index d3a91aae0d..d334d06500 100644
--- a/tests/client.rs
+++ b/tests/client.rs
@@ -1791,7 +1791,7 @@ mod dispatch_impl {
assert_eq!(res.status(), 101);
let upgraded = rt
- .block_on(res.into_body().on_upgrade())
+ .block_on(hyper::upgrade::on(res))
.expect("on_upgrade");
let parts = upgraded.downcast::().unwrap();
diff --git a/tests/server.rs b/tests/server.rs
index 5ecd21a826..72d2e459d8 100644
--- a/tests/server.rs
+++ b/tests/server.rs
@@ -1341,7 +1341,7 @@ async fn upgrades_new() {
let (upgrades_tx, upgrades_rx) = mpsc::channel();
let svc = service_fn(move |req: Request| {
- let on_upgrade = req.into_body().on_upgrade();
+ let on_upgrade = hyper::upgrade::on(req);
let _ = upgrades_tx.send(on_upgrade);
future::ok::<_, hyper::Error>(
Response::builder()
@@ -1448,7 +1448,7 @@ async fn http_connect_new() {
let (upgrades_tx, upgrades_rx) = mpsc::channel();
let svc = service_fn(move |req: Request| {
- let on_upgrade = req.into_body().on_upgrade();
+ let on_upgrade = hyper::upgrade::on(req);
let _ = upgrades_tx.send(on_upgrade);
future::ok::<_, hyper::Error>(
Response::builder()