Skip to content

Commit

Permalink
Fix early-data wakeup loss
Browse files Browse the repository at this point in the history
  • Loading branch information
quininer committed Sep 28, 2021
1 parent db01bce commit 46cb00f
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 28 deletions.
2 changes: 1 addition & 1 deletion tokio-rustls/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ See [examples/server](examples/server/src/main.rs). You can run it with:

```sh
cd examples/server
cargo run -- 127.0.0.1 --cert mycert.der --key mykey.der
cargo run -- 127.0.0.1:8000 --cert mycert.der --key mykey.der
```

### License & Origin
Expand Down
39 changes: 30 additions & 9 deletions tokio-rustls/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use std::task::Waker;
use crate::common::IoSession;
use rustls::Session;

Expand All @@ -9,6 +10,7 @@ pub struct TlsStream<IO> {
pub(crate) io: IO,
pub(crate) session: ClientSession,
pub(crate) state: TlsState,
pub(crate) early_waker: Option<Waker>
}

impl<IO> TlsStream<IO> {
Expand Down Expand Up @@ -59,7 +61,17 @@ where
) -> Poll<io::Result<()>> {
match self.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData(..) => Poll::Pending,
TlsState::EarlyData(..) => {
let this = self.get_mut();
if this.early_waker.as_ref()
.filter(|waker| cx.waker().will_wake(waker))
.is_none()
{
this.early_waker = Some(cx.waker().clone());
}

Poll::Pending
},
TlsState::Stream | TlsState::WriteShutdown => {
let this = self.get_mut();
let mut stream =
Expand Down Expand Up @@ -137,6 +149,11 @@ where

// end
this.state = TlsState::Stream;

if let Some(waker) = this.early_waker.take() {
waker.wake();
}

stream.as_mut_pin().poll_write(cx, buf)
}
_ => stream.as_mut_pin().poll_write(cx, buf),
Expand Down Expand Up @@ -165,26 +182,30 @@ where
}

this.state = TlsState::Stream;

if let Some(waker) = this.early_waker.take() {
waker.wake();
}
}
}

stream.as_mut_pin().poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if self.state.writeable() {
self.session.send_close_notify();
self.state.shutdown_write();
}

#[cfg(feature = "early-data")]
{
// we skip the handshake
if let TlsState::EarlyData(..) = self.state {
return Pin::new(&mut self.io).poll_shutdown(cx);
// complete handshake
if matches!(self.state, TlsState::EarlyData(..)) {
ready!(self.as_mut().poll_flush(cx))?;
}
}

if self.state.writeable() {
self.session.send_close_notify();
self.state.shutdown_write();
}

let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
Expand Down
3 changes: 3 additions & 0 deletions tokio-rustls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ impl TlsConnector {
TlsState::Stream
},

#[cfg(feature = "early-data")]
early_waker: None,

session,
}))
}
Expand Down
62 changes: 44 additions & 18 deletions tokio-rustls/tests/early-data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ use std::process::{Child, Command, Stdio};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWriteExt, ReadBuf};
use tokio::io::{split, AsyncRead, AsyncWriteExt, ReadBuf};
use tokio::net::TcpStream;
use tokio::time::sleep;
use tokio::sync::oneshot;
use tokio_rustls::{client::TlsStream, TlsConnector};

struct Read1<T>(T);
Expand All @@ -21,9 +22,15 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> {

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut buf = [0];
let mut buf = &mut ReadBuf::new(&mut buf);
let mut buf = ReadBuf::new(&mut buf);

ready!(Pin::new(&mut self.0).poll_read(cx, &mut buf))?;
Poll::Pending

if buf.filled().is_empty() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}

Expand All @@ -36,24 +43,43 @@ async fn send(
let stream = TcpStream::connect(&addr).await?;
let domain = webpki::DNSNameRef::try_from_ascii_str("testserver.com").unwrap();

let mut stream = connector.connect(domain, stream).await?;
stream.write_all(data).await?;
stream.flush().await?;
let stream = connector.connect(domain, stream).await?;
let (mut rd, mut wd) = split(stream);
let (notify, wait) = oneshot::channel();

let j = tokio::spawn(async move {
// read to eof
//
// see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html
let mut read_task = Read1(&mut rd);
let mut notify = Some(notify);

// read once, then write
//
// this is a regression test, see https://github.com/tokio-rs/tls/issues/54
future::poll_fn(|cx| {
let ret = Pin::new(&mut read_task).poll(cx)?;
assert_eq!(ret, Poll::Pending);

notify.take().unwrap().send(()).unwrap();

Poll::Ready(Ok(())) as Poll<io::Result<_>>
}).await?;

read_task.await?;

Ok(rd) as io::Result<_>
});

wait.await.unwrap();

// sleep 1s
//
// see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html
let sleep1 = sleep(Duration::from_secs(1));
futures_util::pin_mut!(sleep1);
let mut stream = match future::select(Read1(stream), sleep1).await {
future::Either::Right((_, Read1(stream))) => stream,
future::Either::Left((Err(err), _)) => return Err(err),
future::Either::Left((Ok(_), _)) => unreachable!(),
};
wd.write_all(data).await?;
wd.flush().await?;
wd.shutdown().await?;

stream.shutdown().await?;
let rd: tokio::io::ReadHalf<_> = j.await??;

Ok(stream)
Ok(rd.unsplit(wd))
}

struct DropKill(Child);
Expand Down

0 comments on commit 46cb00f

Please sign in to comment.