From ad19cf8f2324e4b96e9449d3b4e95cc647ee9959 Mon Sep 17 00:00:00 2001 From: Pierre Krieger Date: Thu, 5 Mar 2020 12:42:15 +0100 Subject: [PATCH 1/2] Update multistream-select to stable futures --- core/Cargo.toml | 2 +- core/src/lib.rs | 2 +- core/src/upgrade/apply.rs | 14 +- misc/multistream-select/Cargo.toml | 10 +- misc/multistream-select/src/dialer_select.rs | 231 +++++++---- .../src/length_delimited.rs | 366 ++++++++++-------- misc/multistream-select/src/lib.rs | 31 +- .../multistream-select/src/listener_select.rs | 137 ++++--- misc/multistream-select/src/negotiated.rs | 315 +++++++++------ misc/multistream-select/src/protocol.rs | 113 +++--- misc/multistream-select/src/tests.rs | 248 ++++++------ 11 files changed, 834 insertions(+), 635 deletions(-) diff --git a/core/Cargo.toml b/core/Cargo.toml index a8b83887224..edb7998735e 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -15,7 +15,7 @@ bs58 = "0.3.0" ed25519-dalek = "1.0.0-pre.3" either = "1.5" fnv = "1.0" -futures = { version = "0.3.1", features = ["compat", "io-compat", "executor", "thread-pool"] } +futures = { version = "0.3.1", features = ["executor", "thread-pool"] } futures-timer = "3" lazy_static = "1.2" libsecp256k1 = { version = "0.3.1", optional = true } diff --git a/core/src/lib.rs b/core/src/lib.rs index aa851ee05c1..a4e66486b41 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -41,7 +41,7 @@ mod keys_proto { /// Multi-address re-export. pub use multiaddr; -pub type Negotiated = futures::compat::Compat01As03>>; +pub type Negotiated = multistream_select::Negotiated; mod peer_id; mod translation; diff --git a/core/src/upgrade/apply.rs b/core/src/upgrade/apply.rs index 219766831bd..f3bee044379 100644 --- a/core/src/upgrade/apply.rs +++ b/core/src/upgrade/apply.rs @@ -20,7 +20,7 @@ use crate::{ConnectedPoint, Negotiated}; use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError, ProtocolName}; -use futures::{future::Either, prelude::*, compat::Compat, compat::Compat01As03, compat::Future01CompatExt}; +use futures::{future::Either, prelude::*}; use log::debug; use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture}; use std::{iter, mem, pin::Pin, task::Context, task::Poll}; @@ -48,7 +48,7 @@ where U: InboundUpgrade>, { let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); - let future = multistream_select::listener_select_proto(Compat::new(conn), iter).compat(); + let future = multistream_select::listener_select_proto(conn, iter); InboundUpgradeApply { inner: InboundUpgradeApplyState::Init { future, upgrade: up } } @@ -61,7 +61,7 @@ where U: OutboundUpgrade> { let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); - let future = multistream_select::dialer_select_proto(Compat::new(conn), iter, v).compat(); + let future = multistream_select::dialer_select_proto(conn, iter, v); OutboundUpgradeApply { inner: OutboundUpgradeApplyState::Init { future, upgrade: up } } @@ -82,7 +82,7 @@ where U: InboundUpgrade>, { Init { - future: Compat01As03, NameWrap>>, + future: ListenerSelectFuture>, upgrade: U, }, Upgrade { @@ -117,7 +117,7 @@ where } }; self.inner = InboundUpgradeApplyState::Upgrade { - future: Box::pin(upgrade.upgrade_inbound(Compat01As03::new(io), info.0)) + future: Box::pin(upgrade.upgrade_inbound(io, info.0)) }; } InboundUpgradeApplyState::Upgrade { mut future } => { @@ -158,7 +158,7 @@ where U: OutboundUpgrade> { Init { - future: Compat01As03, NameWrapIter<::IntoIter>>>, + future: DialerSelectFuture::IntoIter>>, upgrade: U }, Upgrade { @@ -193,7 +193,7 @@ where } }; self.inner = OutboundUpgradeApplyState::Upgrade { - future: Box::pin(upgrade.upgrade_outbound(Compat01As03::new(connection), info.0)) + future: Box::pin(upgrade.upgrade_outbound(connection, info.0)) }; } OutboundUpgradeApplyState::Upgrade { mut future } => { diff --git a/misc/multistream-select/Cargo.toml b/misc/multistream-select/Cargo.toml index fee5cffb605..33857b11b9b 100644 --- a/misc/multistream-select/Cargo.toml +++ b/misc/multistream-select/Cargo.toml @@ -11,14 +11,14 @@ edition = "2018" [dependencies] bytes = "0.5" -futures = "0.1" +futures = "0.3" log = "0.4" +pin-project = "0.4.8" smallvec = "1.0" -tokio-io = "0.1" -unsigned-varint = "0.3" +unsigned-varint = "0.3.2" [dev-dependencies] -tokio = "0.1" -tokio-tcp = "0.1" +async-std = "1.2" quickcheck = "0.9.0" rand = "0.7.2" +rw-stream-sink = "0.2.1" diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs index c17d9d80659..1f7bffda175 100644 --- a/misc/multistream-select/src/dialer_select.rs +++ b/misc/multistream-select/src/dialer_select.rs @@ -20,12 +20,11 @@ //! Protocol negotiation strategies for the peer acting as the dialer. +use crate::{Negotiated, NegotiationError}; use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version}; + use futures::{future::Either, prelude::*}; -use log::debug; -use std::{io, iter, mem, convert::TryFrom}; -use tokio_io::{AsyncRead, AsyncWrite}; -use crate::{Negotiated, NegotiationError}; +use std::{convert::TryFrom as _, io, iter, mem, pin::Pin, task::{Context, Poll}}; /// Returns a `Future` that negotiates a protocol on the given I/O stream /// for a peer acting as the _dialer_ (or _initiator_). @@ -60,9 +59,9 @@ where let iter = protocols.into_iter(); // We choose between the "serial" and "parallel" strategies based on the number of protocols. if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) { - Either::A(dialer_select_proto_serial(inner, iter, version)) + Either::Left(dialer_select_proto_serial(inner, iter, version)) } else { - Either::B(dialer_select_proto_parallel(inner, iter, version)) + Either::Right(dialer_select_proto_parallel(inner, iter, version)) } } @@ -129,6 +128,7 @@ where /// A `Future` returned by [`dialer_select_proto_serial`] which negotiates /// a protocol iteratively by considering one protocol after the other. +#[pin_project::pin_project] pub struct DialerSelectSeq where R: AsyncRead + AsyncWrite, @@ -155,83 +155,107 @@ where impl Future for DialerSelectSeq where - R: AsyncRead + AsyncWrite, + // The Unpin bound here is required because we produce a `Negotiated` as the output. + // It also makes the implementation considerably easier to write. + R: AsyncRead + AsyncWrite + Unpin, I: Iterator, I::Item: AsRef<[u8]> { - type Item = (I::Item, Negotiated); - type Error = NegotiationError; + type Output = Result<(I::Item, Negotiated), NegotiationError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); - fn poll(&mut self) -> Poll { loop { - match mem::replace(&mut self.state, SeqState::Done) { + match mem::replace(this.state, SeqState::Done) { SeqState::SendHeader { mut io } => { - if io.start_send(Message::Header(self.version))?.is_not_ready() { - self.state = SeqState::SendHeader { io }; - return Ok(Async::NotReady) + match Pin::new(&mut io).poll_ready(cx)? { + Poll::Ready(()) => {}, + Poll::Pending => { + *this.state = SeqState::SendHeader { io }; + return Poll::Pending + }, + } + + if let Err(err) = Pin::new(&mut io).start_send(Message::Header(*this.version)) { + return Poll::Ready(Err(From::from(err))); } - let protocol = self.protocols.next().ok_or(NegotiationError::Failed)?; - self.state = SeqState::SendProtocol { io, protocol }; + + let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?; + *this.state = SeqState::SendProtocol { io, protocol }; } + SeqState::SendProtocol { mut io, protocol } => { + match Pin::new(&mut io).poll_ready(cx)? { + Poll::Ready(()) => {}, + Poll::Pending => { + *this.state = SeqState::SendProtocol { io, protocol }; + return Poll::Pending + }, + } + let p = Protocol::try_from(protocol.as_ref())?; - if io.start_send(Message::Protocol(p.clone()))?.is_not_ready() { - self.state = SeqState::SendProtocol { io, protocol }; - return Ok(Async::NotReady) + if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) { + return Poll::Ready(Err(From::from(err))); } - debug!("Dialer: Proposed protocol: {}", p); - if self.protocols.peek().is_some() { - self.state = SeqState::FlushProtocol { io, protocol } + log::debug!("Dialer: Proposed protocol: {}", p); + + if this.protocols.peek().is_some() { + *this.state = SeqState::FlushProtocol { io, protocol } } else { - match self.version { - Version::V1 => self.state = SeqState::FlushProtocol { io, protocol }, + match this.version { + Version::V1 => *this.state = SeqState::FlushProtocol { io, protocol }, Version::V1Lazy => { - debug!("Dialer: Expecting proposed protocol: {}", p); - let io = Negotiated::expecting(io.into_reader(), p, self.version); - return Ok(Async::Ready((protocol, io))) + log::debug!("Dialer: Expecting proposed protocol: {}", p); + let io = Negotiated::expecting(io.into_reader(), p, *this.version); + return Poll::Ready(Ok((protocol, io))) } } } } + SeqState::FlushProtocol { mut io, protocol } => { - if io.poll_complete()?.is_not_ready() { - self.state = SeqState::FlushProtocol { io, protocol }; - return Ok(Async::NotReady) + match Pin::new(&mut io).poll_flush(cx)? { + Poll::Ready(()) => *this.state = SeqState::AwaitProtocol { io, protocol }, + Poll::Pending => { + *this.state = SeqState::FlushProtocol { io, protocol }; + return Poll::Pending + }, } - self.state = SeqState::AwaitProtocol { io, protocol } } + SeqState::AwaitProtocol { mut io, protocol } => { - let msg = match io.poll()? { - Async::NotReady => { - self.state = SeqState::AwaitProtocol { io, protocol }; - return Ok(Async::NotReady) + let msg = match Pin::new(&mut io).poll_next(cx)? { + Poll::Ready(Some(msg)) => msg, + Poll::Pending => { + *this.state = SeqState::AwaitProtocol { io, protocol }; + return Poll::Pending } - Async::Ready(None) => - return Err(NegotiationError::from( - io::Error::from(io::ErrorKind::UnexpectedEof))), - Async::Ready(Some(msg)) => msg, + Poll::Ready(None) => + return Poll::Ready(Err(NegotiationError::from( + io::Error::from(io::ErrorKind::UnexpectedEof)))), }; match msg { - Message::Header(v) if v == self.version => { - self.state = SeqState::AwaitProtocol { io, protocol }; + Message::Header(v) if v == *this.version => { + *this.state = SeqState::AwaitProtocol { io, protocol }; } Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => { - debug!("Dialer: Received confirmation for protocol: {}", p); + log::debug!("Dialer: Received confirmation for protocol: {}", p); let (io, remaining) = io.into_inner(); let io = Negotiated::completed(io, remaining); - return Ok(Async::Ready((protocol, io))) + return Poll::Ready(Ok((protocol, io))); } Message::NotAvailable => { - debug!("Dialer: Received rejection of protocol: {}", + log::debug!("Dialer: Received rejection of protocol: {}", String::from_utf8_lossy(protocol.as_ref())); - let protocol = self.protocols.next() - .ok_or(NegotiationError::Failed)?; - self.state = SeqState::SendProtocol { io, protocol } + let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?; + *this.state = SeqState::SendProtocol { io, protocol } } - _ => return Err(ProtocolError::InvalidMessage.into()) + _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), } } + SeqState::Done => panic!("SeqState::poll called after completion") } } @@ -241,6 +265,7 @@ where /// A `Future` returned by [`dialer_select_proto_parallel`] which negotiates /// a protocol selectively by considering all supported protocols of the remote /// "in parallel". +#[pin_project::pin_project] pub struct DialerSelectPar where R: AsyncRead + AsyncWrite, @@ -267,76 +292,110 @@ where impl Future for DialerSelectPar where - R: AsyncRead + AsyncWrite, + // The Unpin bound here is required because we produce a `Negotiated` as the output. + // It also makes the implementation considerably easier to write. + R: AsyncRead + AsyncWrite + Unpin, I: Iterator, I::Item: AsRef<[u8]> { - type Item = (I::Item, Negotiated); - type Error = NegotiationError; + type Output = Result<(I::Item, Negotiated), NegotiationError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); - fn poll(&mut self) -> Poll { loop { - match mem::replace(&mut self.state, ParState::Done) { + match mem::replace(this.state, ParState::Done) { ParState::SendHeader { mut io } => { - if io.start_send(Message::Header(self.version))?.is_not_ready() { - self.state = ParState::SendHeader { io }; - return Ok(Async::NotReady) + match Pin::new(&mut io).poll_ready(cx)? { + Poll::Ready(()) => {}, + Poll::Pending => { + *this.state = ParState::SendHeader { io }; + return Poll::Pending + }, + } + + if let Err(err) = Pin::new(&mut io).start_send(Message::Header(*this.version)) { + return Poll::Ready(Err(From::from(err))); } - self.state = ParState::SendProtocolsRequest { io }; + + *this.state = ParState::SendProtocolsRequest { io }; } + ParState::SendProtocolsRequest { mut io } => { - if io.start_send(Message::ListProtocols)?.is_not_ready() { - self.state = ParState::SendProtocolsRequest { io }; - return Ok(Async::NotReady) + match Pin::new(&mut io).poll_ready(cx)? { + Poll::Ready(()) => {}, + Poll::Pending => { + *this.state = ParState::SendProtocolsRequest { io }; + return Poll::Pending + }, + } + + if let Err(err) = Pin::new(&mut io).start_send(Message::ListProtocols) { + return Poll::Ready(Err(From::from(err))); } - debug!("Dialer: Requested supported protocols."); - self.state = ParState::Flush { io } + + log::debug!("Dialer: Requested supported protocols."); + *this.state = ParState::Flush { io } } + ParState::Flush { mut io } => { - if io.poll_complete()?.is_not_ready() { - self.state = ParState::Flush { io }; - return Ok(Async::NotReady) + match Pin::new(&mut io).poll_flush(cx)? { + Poll::Ready(()) => *this.state = ParState::RecvProtocols { io }, + Poll::Pending => { + *this.state = ParState::Flush { io }; + return Poll::Pending + }, } - self.state = ParState::RecvProtocols { io } } + ParState::RecvProtocols { mut io } => { - let msg = match io.poll()? { - Async::NotReady => { - self.state = ParState::RecvProtocols { io }; - return Ok(Async::NotReady) + let msg = match Pin::new(&mut io).poll_next(cx)? { + Poll::Ready(Some(msg)) => msg, + Poll::Pending => { + *this.state = ParState::RecvProtocols { io }; + return Poll::Pending } - Async::Ready(None) => - return Err(NegotiationError::from( - io::Error::from(io::ErrorKind::UnexpectedEof))), - Async::Ready(Some(msg)) => msg, + Poll::Ready(None) => + return Poll::Ready(Err(NegotiationError::from( + io::Error::from(io::ErrorKind::UnexpectedEof)))), }; match &msg { - Message::Header(v) if v == &self.version => { - self.state = ParState::RecvProtocols { io } + Message::Header(v) if v == this.version => { + *this.state = ParState::RecvProtocols { io } } Message::Protocols(supported) => { - let protocol = self.protocols.by_ref() + let protocol = this.protocols.by_ref() .find(|p| supported.iter().any(|s| s.as_ref() == p.as_ref())) .ok_or(NegotiationError::Failed)?; - debug!("Dialer: Found supported protocol: {}", + log::debug!("Dialer: Found supported protocol: {}", String::from_utf8_lossy(protocol.as_ref())); - self.state = ParState::SendProtocol { io, protocol }; + *this.state = ParState::SendProtocol { io, protocol }; } - _ => return Err(ProtocolError::InvalidMessage.into()) + _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), } } + ParState::SendProtocol { mut io, protocol } => { + match Pin::new(&mut io).poll_ready(cx)? { + Poll::Ready(()) => {}, + Poll::Pending => { + *this.state = ParState::SendProtocol { io, protocol }; + return Poll::Pending + }, + } + let p = Protocol::try_from(protocol.as_ref())?; - if io.start_send(Message::Protocol(p.clone()))?.is_not_ready() { - self.state = ParState::SendProtocol { io, protocol }; - return Ok(Async::NotReady) + if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) { + return Poll::Ready(Err(From::from(err))); } - debug!("Dialer: Expecting proposed protocol: {}", p); - let io = Negotiated::expecting(io.into_reader(), p, self.version); - return Ok(Async::Ready((protocol, io))) + log::debug!("Dialer: Expecting proposed protocol: {}", p); + + let io = Negotiated::expecting(io.into_reader(), p, *this.version); + return Poll::Ready(Ok((protocol, io))) } + ParState::Done => panic!("ParState::poll called after completion") } } diff --git a/misc/multistream-select/src/length_delimited.rs b/misc/multistream-select/src/length_delimited.rs index 3efd700e6e8..42e82a85040 100644 --- a/misc/multistream-select/src/length_delimited.rs +++ b/misc/multistream-select/src/length_delimited.rs @@ -18,11 +18,9 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use bytes::{Bytes, BytesMut, Buf, BufMut}; -use futures::{try_ready, Async, Poll, Sink, StartSend, Stream, AsyncSink}; -use std::{io, u16}; -use tokio_io::{AsyncRead, AsyncWrite}; -use unsigned_varint as uvi; +use bytes::{Bytes, BytesMut, Buf as _, BufMut as _}; +use futures::{prelude::*, io::IoSlice}; +use std::{convert::TryFrom as _, io, pin::Pin, task::{Poll, Context}, u16}; const MAX_LEN_BYTES: u16 = 2; const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1; @@ -34,9 +32,11 @@ const DEFAULT_BUFFER_SIZE: usize = 64; /// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint /// frame length). Frames mostly consist in a short protocol name, which is highly /// unlikely to be more than 16KiB long. +#[pin_project::pin_project] #[derive(Debug)] pub struct LengthDelimited { /// The inner I/O resource. + #[pin] inner: R, /// Read buffer for a single incoming unsigned-varint length-delimited frame. read_buffer: BytesMut, @@ -76,20 +76,7 @@ impl LengthDelimited { } } - /// Returns a reference to the underlying I/O stream. - pub fn inner_ref(&self) -> &R { - &self.inner - } - - /// Returns a mutable reference to the underlying I/O stream. - /// - /// > **Note**: Care should be taken to not tamper with the underlying stream of data - /// > coming in, as it may corrupt the stream of frames. - pub fn inner_mut(&mut self) -> &mut R { - &mut self.inner - } - - /// Drops the `LengthDelimited` resource, yielding the underlying I/O stream + /// Drops the [`LengthDelimited`] resource, yielding the underlying I/O stream /// together with the remaining write buffer containing the uvi-framed data /// that has not yet been written to the underlying I/O stream. /// @@ -107,7 +94,7 @@ impl LengthDelimited { (self.inner, self.write_buffer) } - /// Converts the `LengthDelimited` into a `LengthDelimitedReader`, dropping the + /// Converts the [`LengthDelimited`] into a [`LengthDelimitedReader`], dropping the /// uvi-framed `Sink` in favour of direct `AsyncWrite` access to the underlying /// I/O stream. /// @@ -121,25 +108,29 @@ impl LengthDelimited { /// Writes all buffered frame data to the underlying I/O stream, /// _without flushing it_. /// - /// After this method returns `Async::Ready`, the write buffer of frames + /// After this method returns `Poll::Ready`, the write buffer of frames /// submitted to the `Sink` is guaranteed to be empty. - pub fn poll_write_buffer(&mut self) -> Poll<(), io::Error> + pub fn poll_write_buffer(self: Pin<&mut Self>, cx: &mut Context) + -> Poll> where R: AsyncWrite { - while !self.write_buffer.is_empty() { - let n = try_ready!(self.inner.poll_write(&self.write_buffer)); - - if n == 0 { - return Err(io::Error::new( - io::ErrorKind::WriteZero, - "Failed to write buffered frame.")) + let mut this = self.project(); + + while !this.write_buffer.is_empty() { + match this.inner.as_mut().poll_write(cx, &this.write_buffer) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(0)) => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "Failed to write buffered frame."))) + } + Poll::Ready(Ok(n)) => this.write_buffer.advance(n), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), } - - self.write_buffer.advance(n); } - Ok(Async::Ready(())) + Poll::Ready(Ok(())) } } @@ -147,72 +138,67 @@ impl Stream for LengthDelimited where R: AsyncRead { - type Item = Bytes; - type Error = io::Error; + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let mut this = self.project(); - fn poll(&mut self) -> Poll, Self::Error> { loop { - match &mut self.read_state { + match this.read_state { ReadState::ReadLength { buf, pos } => { - match self.inner.read(&mut buf[*pos .. *pos + 1]) { - Ok(0) => { + match this.inner.as_mut().poll_read(cx, &mut buf[*pos .. *pos + 1]) { + Poll::Ready(Ok(0)) => { if *pos == 0 { - return Ok(Async::Ready(None)); + return Poll::Ready(None); } else { - return Err(io::ErrorKind::UnexpectedEof.into()); + return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))); } } - Ok(n) => { + Poll::Ready(Ok(n)) => { debug_assert_eq!(n, 1); *pos += n; } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - return Ok(Async::NotReady); - } - Err(err) => { - return Err(err); - } + Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))), + Poll::Pending => return Poll::Pending, }; if (buf[*pos - 1] & 0x80) == 0 { // MSB is not set, indicating the end of the length prefix. - let (len, _) = uvi::decode::u16(buf).map_err(|e| { - log::debug!("invalid length prefix: {}", e); - io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix") - })?; + let (len, _) = unsigned_varint::decode::u16(buf) + .map_err(|e| { + log::debug!("invalid length prefix: {}", e); + io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix") + })?; if len >= 1 { - self.read_state = ReadState::ReadData { len, pos: 0 }; - self.read_buffer.resize(len as usize, 0); + *this.read_state = ReadState::ReadData { len, pos: 0 }; + this.read_buffer.resize(len as usize, 0); } else { debug_assert_eq!(len, 0); - self.read_state = ReadState::default(); - return Ok(Async::Ready(Some(Bytes::new()))); + *this.read_state = ReadState::default(); + return Poll::Ready(Some(Ok(Bytes::new()))); } } else if *pos == MAX_LEN_BYTES as usize { // MSB signals more length bytes but we have already read the maximum. // See the module documentation about the max frame len. - return Err(io::Error::new( + return Poll::Ready(Some(Err(io::Error::new( io::ErrorKind::InvalidData, - "Maximum frame length exceeded")); + "Maximum frame length exceeded")))); } } ReadState::ReadData { len, pos } => { - match self.inner.read(&mut self.read_buffer[*pos..]) { - Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()), - Ok(n) => *pos += n, - Err(err) => - if err.kind() == io::ErrorKind::WouldBlock { - return Ok(Async::NotReady) - } else { - return Err(err) - } + match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) { + Poll::Ready(Ok(0)) => return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))), + Poll::Ready(Ok(n)) => *pos += n, + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))), }; + if *pos == *len as usize { // Finished reading the frame. - let frame = self.read_buffer.split_off(0).freeze(); - self.read_state = ReadState::default(); - return Ok(Async::Ready(Some(frame))); + let frame = this.read_buffer.split_off(0).freeze(); + *this.read_state = ReadState::default(); + return Poll::Ready(Some(Ok(frame))); } } } @@ -220,58 +206,87 @@ where } } -impl Sink for LengthDelimited +impl Sink for LengthDelimited where R: AsyncWrite, { - type SinkItem = Bytes; - type SinkError = io::Error; + type Error = io::Error; - fn start_send(&mut self, msg: Self::SinkItem) -> StartSend { + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { // Use the maximum frame length also as a (soft) upper limit // for the entire write buffer. The actual (hard) limit is thus // implied to be roughly 2 * MAX_FRAME_SIZE. - if self.write_buffer.len() >= MAX_FRAME_SIZE as usize { - self.poll_complete()?; - if self.write_buffer.len() >= MAX_FRAME_SIZE as usize { - return Ok(AsyncSink::NotReady(msg)) + if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize { + match self.as_mut().poll_write_buffer(cx) { + Poll::Ready(Ok(())) => {}, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, } - } - let len = msg.len() as u16; - if len > MAX_FRAME_SIZE { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Maximum frame size exceeded.")) + debug_assert!(self.as_mut().project().write_buffer.is_empty()); } - let mut uvi_buf = uvi::encode::u16_buffer(); - let uvi_len = uvi::encode::u16(len, &mut uvi_buf); - self.write_buffer.reserve(len as usize + uvi_len.len()); - self.write_buffer.put(uvi_len); - self.write_buffer.put(msg); + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + let this = self.project(); - Ok(AsyncSink::Ready) + let len = match u16::try_from(item.len()) { + Ok(len) if len <= MAX_FRAME_SIZE => len, + _ => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Maximum frame size exceeded.")) + } + }; + + let mut uvi_buf = unsigned_varint::encode::u16_buffer(); + let uvi_len = unsigned_varint::encode::u16(len, &mut uvi_buf); + this.write_buffer.reserve(len as usize + uvi_len.len()); + this.write_buffer.put(uvi_len); + this.write_buffer.put(item); + + Ok(()) } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { // Write all buffered frame data to the underlying I/O stream. - try_ready!(self.poll_write_buffer()); + match LengthDelimited::poll_write_buffer(self.as_mut(), cx) { + Poll::Ready(Ok(())) => {}, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + } + + let this = self.project(); + debug_assert!(this.write_buffer.is_empty()); + // Flush the underlying I/O stream. - try_ready!(self.inner.poll_flush()); - return Ok(Async::Ready(())); + this.inner.poll_flush(cx) } - fn close(&mut self) -> Poll<(), Self::SinkError> { - try_ready!(self.poll_complete()); - Ok(self.inner.shutdown()?) + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + // Write all buffered frame data to the underlying I/O stream. + match LengthDelimited::poll_write_buffer(self.as_mut(), cx) { + Poll::Ready(Ok(())) => {}, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + } + + let this = self.project(); + debug_assert!(this.write_buffer.is_empty()); + + // Close the underlying I/O stream. + this.inner.poll_close(cx) } } /// A `LengthDelimitedReader` implements a `Stream` of uvi-length-delimited /// frames on an underlying I/O resource combined with direct `AsyncWrite` access. +#[pin_project::pin_project] #[derive(Debug)] pub struct LengthDelimitedReader { + #[pin] inner: LengthDelimited } @@ -291,75 +306,79 @@ impl LengthDelimitedReader { pub fn into_inner(self) -> (R, BytesMut) { self.inner.into_inner() } - - /// Returns a reference to the underlying I/O stream. - pub fn inner_ref(&self) -> &R { - self.inner.inner_ref() - } - - /// Returns a mutable reference to the underlying I/O stream. - /// - /// > **Note**: Care should be taken to not tamper with the underlying stream of data - /// > coming in, as it may corrupt the stream of frames. - pub fn inner_mut(&mut self) -> &mut R { - self.inner.inner_mut() - } } impl Stream for LengthDelimitedReader where R: AsyncRead { - type Item = Bytes; - type Error = io::Error; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - self.inner.poll() + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().inner.poll_next(cx) } } -impl io::Write for LengthDelimitedReader +impl AsyncWrite for LengthDelimitedReader where R: AsyncWrite { - fn write(&mut self, buf: &[u8]) -> io::Result { - while !self.inner.write_buffer.is_empty() { - if self.inner.poll_write_buffer()?.is_not_ready() { - return Err(io::ErrorKind::WouldBlock.into()) - } + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) + -> Poll> + { + // `this` here designates the `LengthDelimited`. + let mut this = self.project().inner; + + // We need to flush any data previously written with the `LengthDelimited`. + match LengthDelimited::poll_write_buffer(this.as_mut(), cx) { + Poll::Ready(Ok(())) => {}, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, } - self.inner_mut().write(buf) + debug_assert!(this.write_buffer.is_empty()); + + this.project().inner.poll_write(cx, buf) } - fn flush(&mut self) -> io::Result<()> { - match self.inner.poll_complete()? { - Async::Ready(()) => Ok(()), - Async::NotReady => Err(io::ErrorKind::WouldBlock.into()) - } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().inner.poll_flush(cx) } -} -impl AsyncWrite for LengthDelimitedReader -where - R: AsyncWrite -{ - fn shutdown(&mut self) -> Poll<(), io::Error> { - try_ready!(self.inner.poll_complete()); - self.inner_mut().shutdown() + fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().inner.poll_close(cx) + } + + fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice]) + -> Poll> + { + // `this` here designates the `LengthDelimited`. + let mut this = self.project().inner; + + // We need to flush any data previously written with the `LengthDelimited`. + match LengthDelimited::poll_write_buffer(this.as_mut(), cx) { + Poll::Ready(Ok(())) => {}, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + } + debug_assert!(this.write_buffer.is_empty()); + + this.project().inner.poll_write_vectored(cx, bufs) } } #[cfg(test)] mod tests { - use futures::{Future, Stream}; use crate::length_delimited::LengthDelimited; - use std::io::{Cursor, ErrorKind}; + use async_std::net::{TcpListener, TcpStream}; + use futures::{prelude::*, io::Cursor}; + use quickcheck::*; + use std::io::ErrorKind; #[test] fn basic_read() { let data = vec![6, 9, 8, 7, 6, 5, 4]; let framed = LengthDelimited::new(Cursor::new(data)); - let recved = framed.collect().wait().unwrap(); + let recved = futures::executor::block_on(framed.try_collect::>()).unwrap(); assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]); } @@ -367,7 +386,7 @@ mod tests { fn basic_read_two() { let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7]; let framed = LengthDelimited::new(Cursor::new(data)); - let recved = framed.collect().wait().unwrap(); + let recved = futures::executor::block_on(framed.try_collect::>()).unwrap(); assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]); } @@ -378,13 +397,10 @@ mod tests { let frame = (0..len).map(|n| (n & 0xff) as u8).collect::>(); let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8]; data.extend(frame.clone().into_iter()); - let framed = LengthDelimited::new(Cursor::new(data)); - let recved = framed - .into_future() - .map(|(m, _)| m) - .map_err(|_| ()) - .wait() - .unwrap(); + let mut framed = LengthDelimited::new(Cursor::new(data)); + let recved = futures::executor::block_on(async move { + framed.next().await + }).unwrap(); assert_eq!(recved.unwrap(), frame); } @@ -392,12 +408,10 @@ mod tests { fn packet_len_too_long() { let mut data = vec![0x81, 0x81, 0x1]; data.extend((0..16513).map(|_| 0)); - let framed = LengthDelimited::new(Cursor::new(data)); - let recved = framed - .into_future() - .map(|(m, _)| m) - .map_err(|(err, _)| err) - .wait(); + let mut framed = LengthDelimited::new(Cursor::new(data)); + let recved = futures::executor::block_on(async move { + framed.next().await.unwrap() + }); if let Err(io_err) = recved { assert_eq!(io_err.kind(), ErrorKind::InvalidData) @@ -410,7 +424,7 @@ mod tests { fn empty_frames() { let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7]; let framed = LengthDelimited::new(Cursor::new(data)); - let recved = framed.collect().wait().unwrap(); + let recved = futures::executor::block_on(framed.try_collect::>()).unwrap(); assert_eq!( recved, vec![ @@ -427,7 +441,7 @@ mod tests { fn unexpected_eof_in_len() { let data = vec![0x89]; let framed = LengthDelimited::new(Cursor::new(data)); - let recved = framed.collect().wait(); + let recved = futures::executor::block_on(framed.try_collect::>()); if let Err(io_err) = recved { assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof) } else { @@ -439,7 +453,7 @@ mod tests { fn unexpected_eof_in_data() { let data = vec![5]; let framed = LengthDelimited::new(Cursor::new(data)); - let recved = framed.collect().wait(); + let recved = futures::executor::block_on(framed.try_collect::>()); if let Err(io_err) = recved { assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof) } else { @@ -451,12 +465,54 @@ mod tests { fn unexpected_eof_in_data2() { let data = vec![5, 9, 8, 7]; let framed = LengthDelimited::new(Cursor::new(data)); - let recved = framed.collect().wait(); + let recved = futures::executor::block_on(framed.try_collect::>()); if let Err(io_err) = recved { assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof) } else { panic!() } } -} + #[test] + fn writing_reading() { + fn prop(frames: Vec>) -> TestResult { + async_std::task::block_on(async move { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let listener_addr = listener.local_addr().unwrap(); + + let expected_frames = frames.clone(); + let server = async_std::task::spawn(async move { + let socket = listener.accept().await.unwrap().0; + let mut connec = rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket)); + + let mut buf = vec![0u8; 0]; + for expected in expected_frames { + if expected.is_empty() { + continue; + } + if buf.len() < expected.len() { + buf.resize(expected.len(), 0); + } + let n = connec.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..n], &expected[..]); + } + }); + + let client = async_std::task::spawn(async move { + let socket = TcpStream::connect(&listener_addr).await.unwrap(); + let mut connec = LengthDelimited::new(socket); + for frame in frames { + connec.send(From::from(frame)).await.unwrap(); + } + }); + + server.await; + client.await; + }); + + TestResult::passed() + } + + quickcheck(prop as fn(_) -> _) + } +} diff --git a/misc/multistream-select/src/lib.rs b/misc/multistream-select/src/lib.rs index 6ab6eabedbf..abb088be610 100644 --- a/misc/multistream-select/src/lib.rs +++ b/misc/multistream-select/src/lib.rs @@ -77,26 +77,19 @@ //! //! ```no_run //! # fn main() { -//! use bytes::Bytes; +//! use async_std::net::TcpStream; //! use multistream_select::{dialer_select_proto, Version}; -//! use futures::{Future, Sink, Stream}; -//! use tokio_tcp::TcpStream; -//! use tokio::runtime::current_thread::Runtime; -//! -//! #[derive(Debug, Copy, Clone)] -//! enum MyProto { Echo, Hello } -//! -//! let client = TcpStream::connect(&"127.0.0.1:10333".parse().unwrap()) -//! .from_err() -//! .and_then(move |io| { -//! let protos = vec![b"/echo/1.0.0", b"/echo/2.5.0"]; -//! dialer_select_proto(io, protos, Version::V1) -//! }) -//! .map(|(protocol, _io)| protocol); -//! -//! let mut rt = Runtime::new().unwrap(); -//! let protocol = rt.block_on(client).expect("failed to find a protocol"); -//! println!("Negotiated protocol: {:?}", protocol); +//! use futures::prelude::*; +//! +//! async_std::task::block_on(async move { +//! let socket = TcpStream::connect("127.0.0.1:10333").await.unwrap(); +//! +//! let protos = vec![b"/echo/1.0.0", b"/echo/2.5.0"]; +//! let (protocol, _io) = dialer_select_proto(socket, protos, Version::V1).await.unwrap(); +//! +//! println!("Negotiated protocol: {:?}", protocol); +//! // You can now use `_io` to communicate with the remote. +//! }); //! # } //! ``` //! diff --git a/misc/multistream-select/src/listener_select.rs b/misc/multistream-select/src/listener_select.rs index f6a39bfb0f7..93e8ec61b4e 100644 --- a/misc/multistream-select/src/listener_select.rs +++ b/misc/multistream-select/src/listener_select.rs @@ -21,13 +21,12 @@ //! Protocol negotiation strategies for the peer acting as the listener //! in a multistream-select protocol negotiation. -use futures::prelude::*; +use crate::{Negotiated, NegotiationError}; use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version}; -use log::{debug, warn}; + +use futures::prelude::*; use smallvec::SmallVec; -use std::{io, iter::FromIterator, mem, convert::TryFrom}; -use tokio_io::{AsyncRead, AsyncWrite}; -use crate::{Negotiated, NegotiationError}; +use std::{convert::TryFrom as _, io, iter::FromIterator, mem, pin::Pin, task::{Context, Poll}}; /// Returns a `Future` that negotiates a protocol on the given I/O stream /// for a peer acting as the _listener_ (or _responder_). @@ -49,7 +48,7 @@ where match Protocol::try_from(n.as_ref()) { Ok(p) => Some((n, p)), Err(e) => { - warn!("Listener: Ignoring invalid protocol: {} due to {}", + log::warn!("Listener: Ignoring invalid protocol: {} due to {}", String::from_utf8_lossy(n.as_ref()), e); None } @@ -64,6 +63,7 @@ where /// The `Future` returned by [`listener_select_proto`] that performs a /// multistream-select protocol negotiation on an underlying I/O stream. +#[pin_project::pin_project] pub struct ListenerSelectFuture where R: AsyncRead + AsyncWrite, @@ -94,64 +94,80 @@ where impl Future for ListenerSelectFuture where - R: AsyncRead + AsyncWrite, + // The Unpin bound here is required because we produce a `Negotiated` as the output. + // It also makes the implementation considerably easier to write. + R: AsyncRead + AsyncWrite + Unpin, N: AsRef<[u8]> + Clone { - type Item = (N, Negotiated); - type Error = NegotiationError; + type Output = Result<(N, Negotiated), NegotiationError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); - fn poll(&mut self) -> Poll { loop { - match mem::replace(&mut self.state, State::Done) { + match mem::replace(this.state, State::Done) { State::RecvHeader { mut io } => { - match io.poll()? { - Async::Ready(Some(Message::Header(version))) => { - self.state = State::SendHeader { io, version } - } - Async::Ready(Some(_)) => { - return Err(ProtocolError::InvalidMessage.into()) + match io.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(Message::Header(version)))) => { + *this.state = State::SendHeader { io, version } } - Async::Ready(None) => - return Err(NegotiationError::from( + Poll::Ready(Some(Ok(_))) => { + return Poll::Ready(Err(ProtocolError::InvalidMessage.into())) + }, + Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))), + Poll::Ready(None) => + return Poll::Ready(Err(NegotiationError::from( ProtocolError::IoError( - io::ErrorKind::UnexpectedEof.into()))), - Async::NotReady => { - self.state = State::RecvHeader { io }; - return Ok(Async::NotReady) + io::ErrorKind::UnexpectedEof.into())))), + Poll::Pending => { + *this.state = State::RecvHeader { io }; + return Poll::Pending } } } + State::SendHeader { mut io, version } => { - if io.start_send(Message::Header(version))?.is_not_ready() { - return Ok(Async::NotReady) + match Pin::new(&mut io).poll_ready(cx) { + Poll::Pending => { + *this.state = State::SendHeader { io, version }; + return Poll::Pending + }, + Poll::Ready(Ok(())) => {}, + Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), + } + + if let Err(err) = Pin::new(&mut io).start_send(Message::Header(version)) { + return Poll::Ready(Err(From::from(err))); } - self.state = match version { + + *this.state = match version { Version::V1 => State::Flush { io }, Version::V1Lazy => State::RecvMessage { io }, } } + State::RecvMessage { mut io } => { - let msg = match io.poll() { - Ok(Async::Ready(Some(msg))) => msg, - Ok(Async::Ready(None)) => - return Err(NegotiationError::from( + let msg = match Pin::new(&mut io).poll_next(cx) { + Poll::Ready(Some(Ok(msg))) => msg, + Poll::Ready(None) => + return Poll::Ready(Err(NegotiationError::from( ProtocolError::IoError( - io::ErrorKind::UnexpectedEof.into()))), - Ok(Async::NotReady) => { - self.state = State::RecvMessage { io }; - return Ok(Async::NotReady) + io::ErrorKind::UnexpectedEof.into())))), + Poll::Pending => { + *this.state = State::RecvMessage { io }; + return Poll::Pending; } - Err(e) => return Err(e.into()) + Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))), }; match msg { Message::ListProtocols => { - let supported = self.protocols.iter().map(|(_,p)| p).cloned().collect(); + let supported = this.protocols.iter().map(|(_,p)| p).cloned().collect(); let message = Message::Protocols(supported); - self.state = State::SendMessage { io, message, protocol: None } + *this.state = State::SendMessage { io, message, protocol: None } } Message::Protocol(p) => { - let protocol = self.protocols.iter().find_map(|(name, proto)| { + let protocol = this.protocols.iter().find_map(|(name, proto)| { if &p == proto { Some(name.clone()) } else { @@ -160,45 +176,60 @@ where }); let message = if protocol.is_some() { - debug!("Listener: confirming protocol: {}", p); + log::debug!("Listener: confirming protocol: {}", p); Message::Protocol(p.clone()) } else { - debug!("Listener: rejecting protocol: {}", + log::debug!("Listener: rejecting protocol: {}", String::from_utf8_lossy(p.as_ref())); Message::NotAvailable }; - self.state = State::SendMessage { io, message, protocol }; + *this.state = State::SendMessage { io, message, protocol }; } - _ => return Err(ProtocolError::InvalidMessage.into()) + _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())) } } + State::SendMessage { mut io, message, protocol } => { - if let AsyncSink::NotReady(message) = io.start_send(message)? { - self.state = State::SendMessage { io, message, protocol }; - return Ok(Async::NotReady) - }; + match Pin::new(&mut io).poll_ready(cx) { + Poll::Pending => { + *this.state = State::SendMessage { io, message, protocol }; + return Poll::Pending + }, + Poll::Ready(Ok(())) => {}, + Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), + } + + if let Err(err) = Pin::new(&mut io).start_send(message) { + return Poll::Ready(Err(From::from(err))); + } + // If a protocol has been selected, finish negotiation. // Otherwise flush the sink and expect to receive another // message. - self.state = match protocol { + *this.state = match protocol { Some(protocol) => { - debug!("Listener: sent confirmed protocol: {}", + log::debug!("Listener: sent confirmed protocol: {}", String::from_utf8_lossy(protocol.as_ref())); let (io, remaining) = io.into_inner(); let io = Negotiated::completed(io, remaining); - return Ok(Async::Ready((protocol, io))) + return Poll::Ready(Ok((protocol, io))); } None => State::Flush { io } }; } + State::Flush { mut io } => { - if io.poll_complete()?.is_not_ready() { - self.state = State::Flush { io }; - return Ok(Async::NotReady) + match Pin::new(&mut io).poll_flush(cx) { + Poll::Pending => { + *this.state = State::Flush { io }; + return Poll::Pending + }, + Poll::Ready(Ok(())) => *this.state = State::RecvMessage { io }, + Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), } - self.state = State::RecvMessage { io } } + State::Done => panic!("State::poll called after completion") } } diff --git a/misc/multistream-select/src/negotiated.rs b/misc/multistream-select/src/negotiated.rs index ff50d3c71ff..f5b368c63fb 100644 --- a/misc/multistream-select/src/negotiated.rs +++ b/misc/multistream-select/src/negotiated.rs @@ -18,12 +18,12 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use bytes::{BytesMut, Buf}; use crate::protocol::{Protocol, MessageReader, Message, Version, ProtocolError}; -use futures::{prelude::*, Async, try_ready}; -use log::debug; -use tokio_io::{AsyncRead, AsyncWrite}; -use std::{mem, io, fmt, error::Error}; + +use bytes::{BytesMut, Buf}; +use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready}; +use pin_project::{pin_project, project}; +use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}}; /// An I/O stream that has settled on an (application-layer) protocol to use. /// @@ -36,28 +36,40 @@ use std::{mem, io, fmt, error::Error}; /// /// Reading from a `Negotiated` I/O stream that still has pending negotiation /// protocol data to send implicitly triggers flushing of all yet unsent data. +#[pin_project] #[derive(Debug)] pub struct Negotiated { + #[pin] state: State } /// A `Future` that waits on the completion of protocol negotiation. #[derive(Debug)] pub struct NegotiatedComplete { - inner: Option> + inner: Option>, } -impl Future for NegotiatedComplete { - type Item = Negotiated; - type Error = NegotiationError; +impl Future for NegotiatedComplete +where + // `Unpin` is required not because of implementation details but because we produce the + // `Negotiated` as the output of the future. + TInner: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result, NegotiationError>; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let mut io = self.inner.take().expect("NegotiatedFuture called after completion."); - if io.poll()?.is_not_ready() { - self.inner = Some(io); - return Ok(Async::NotReady) + match Negotiated::poll(Pin::new(&mut io), cx) { + Poll::Pending => { + self.inner = Some(io); + return Poll::Pending + }, + Poll::Ready(Ok(())) => Poll::Ready(Ok(io)), + Poll::Ready(Err(err)) => { + self.inner = Some(io); + return Poll::Ready(Err(err)); + } } - return Ok(Async::Ready(io)) } } @@ -75,66 +87,67 @@ impl Negotiated { } /// Polls the `Negotiated` for completion. - fn poll(&mut self) -> Poll<(), NegotiationError> + #[project] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> where - TInner: AsyncRead + AsyncWrite + TInner: AsyncRead + AsyncWrite + Unpin { // Flush any pending negotiation data. - match self.poll_flush() { - Ok(Async::Ready(())) => {}, - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(e) => { + match self.as_mut().poll_flush(cx) { + Poll::Ready(Ok(())) => {}, + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => { // If the remote closed the stream, it is important to still // continue reading the data that was sent, if any. if e.kind() != io::ErrorKind::WriteZero { - return Err(e.into()) + return Poll::Ready(Err(e.into())) } } } - if let State::Completed { remaining, .. } = &mut self.state { - let _ = remaining.split_to(remaining.len()); // Drop remaining data flushed above. - return Ok(Async::Ready(())) + let mut this = self.project(); + + #[project] + match this.state.as_mut().project() { + State::Completed { remaining, .. } => { + debug_assert!(remaining.is_empty()); + return Poll::Ready(Ok(())) + } + _ => {} } // Read outstanding protocol negotiation messages. loop { - match mem::replace(&mut self.state, State::Invalid) { + match mem::replace(&mut *this.state, State::Invalid) { State::Expecting { mut io, protocol, version } => { - let msg = match io.poll() { - Ok(Async::Ready(Some(msg))) => msg, - Ok(Async::NotReady) => { - self.state = State::Expecting { io, protocol, version }; - return Ok(Async::NotReady) - } - Ok(Async::Ready(None)) => { - self.state = State::Expecting { io, protocol, version }; - return Err(ProtocolError::IoError( - io::ErrorKind::UnexpectedEof.into()).into()) - } - Err(err) => { - self.state = State::Expecting { io, protocol, version }; - return Err(err.into()) + let msg = match Pin::new(&mut io).poll_next(cx)? { + Poll::Ready(Some(msg)) => msg, + Poll::Pending => { + *this.state = State::Expecting { io, protocol, version }; + return Poll::Pending + }, + Poll::Ready(None) => { + return Poll::Ready(Err(ProtocolError::IoError( + io::ErrorKind::UnexpectedEof.into()).into())); } }; if let Message::Header(v) = &msg { - if v == &version { - self.state = State::Expecting { io, protocol, version }; + if *v == version { continue } } if let Message::Protocol(p) = &msg { if p.as_ref() == protocol.as_ref() { - debug!("Negotiated: Received confirmation for protocol: {}", p); + log::debug!("Negotiated: Received confirmation for protocol: {}", p); let (io, remaining) = io.into_inner(); - self.state = State::Completed { io, remaining }; - return Ok(Async::Ready(())) + *this.state = State::Completed { io, remaining }; + return Poll::Ready(Ok(())); } } - return Err(NegotiationError::Failed) + return Poll::Ready(Err(NegotiationError::Failed)); } _ => panic!("Negotiated: Invalid state") @@ -142,7 +155,7 @@ impl Negotiated { } } - /// Returns a `NegotiatedComplete` future that waits for protocol + /// Returns a [`NegotiatedComplete`] future that waits for protocol /// negotiation to complete. pub fn complete(self) -> NegotiatedComplete { NegotiatedComplete { inner: Some(self) } @@ -150,12 +163,14 @@ impl Negotiated { } /// The states of a `Negotiated` I/O stream. +#[pin_project] #[derive(Debug)] enum State { /// In this state, a `Negotiated` is still expecting to /// receive confirmation of the protocol it as settled on. Expecting { /// The underlying I/O stream. + #[pin] io: MessageReader, /// The expected protocol (i.e. name and version). protocol: Protocol, @@ -167,113 +182,157 @@ enum State { /// only be pending the sending of the final acknowledgement, /// which is prepended to / combined with the next write for /// efficiency. - Completed { io: R, remaining: BytesMut }, + Completed { #[pin] io: R, remaining: BytesMut }, /// Temporary state while moving the `io` resource from /// `Expecting` to `Completed`. Invalid, } -impl io::Read for Negotiated +impl AsyncRead for Negotiated where - R: AsyncRead + AsyncWrite + TInner: AsyncRead + AsyncWrite + Unpin { - fn read(&mut self, buf: &mut [u8]) -> io::Result { + #[project] + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) + -> Poll> + { loop { - if let State::Completed { io, remaining } = &mut self.state { - // If protocol negotiation is complete and there is no - // remaining data to be flushed, commence with reading. - if remaining.is_empty() { - return io.read(buf) - } + #[project] + match self.as_mut().project().state.project() { + State::Completed { io, remaining } => { + // If protocol negotiation is complete and there is no + // remaining data to be flushed, commence with reading. + if remaining.is_empty() { + return io.poll_read(cx, buf) + } + }, + _ => {} } // Poll the `Negotiated`, driving protocol negotiation to completion, // including flushing of any remaining data. - let result = self.poll(); - - // There is still remaining data to be sent before data relating - // to the negotiated protocol can be read. - if let Ok(Async::NotReady) = result { - return Err(io::ErrorKind::WouldBlock.into()) - } - - if let Err(err) = result { - return Err(err.into()) + match self.as_mut().poll(cx) { + Poll::Ready(Ok(())) => {}, + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), } } } -} -impl AsyncRead for Negotiated -where - TInner: AsyncRead + AsyncWrite -{ - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + // TODO: implement once method is stabilized in the futures crate + /*unsafe fn initializer(&self) -> Initializer { match &self.state { - State::Completed { io, .. } => - io.prepare_uninitialized_buffer(buf), - State::Expecting { io, .. } => - io.inner_ref().prepare_uninitialized_buffer(buf), - State::Invalid => panic!("Negotiated: Invalid state") + State::Completed { io, .. } => io.initializer(), + State::Expecting { io, .. } => io.inner_ref().initializer(), + State::Invalid => panic!("Negotiated: Invalid state"), + } + }*/ + + #[project] + fn poll_read_vectored(mut self: Pin<&mut Self>, cx: &mut Context, bufs: &mut [IoSliceMut]) + -> Poll> + { + loop { + #[project] + match self.as_mut().project().state.project() { + State::Completed { io, remaining } => { + // If protocol negotiation is complete and there is no + // remaining data to be flushed, commence with reading. + if remaining.is_empty() { + return io.poll_read_vectored(cx, bufs) + } + }, + _ => {} + } + + // Poll the `Negotiated`, driving protocol negotiation to completion, + // including flushing of any remaining data. + match self.as_mut().poll(cx) { + Poll::Ready(Ok(())) => {}, + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), + } } } } -impl io::Write for Negotiated +impl AsyncWrite for Negotiated where - TInner: AsyncWrite + TInner: AsyncWrite + AsyncRead + Unpin { - fn write(&mut self, buf: &[u8]) -> io::Result { - match &mut self.state { - State::Completed { io, ref mut remaining } => { + #[project] + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + #[project] + match self.project().state.project() { + State::Completed { mut io, remaining } => { while !remaining.is_empty() { - let n = io.write(&remaining)?; + let n = ready!(io.as_mut().poll_write(cx, &remaining)?); if n == 0 { - return Err(io::ErrorKind::WriteZero.into()) + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) } remaining.advance(n); } - io.write(buf) + io.poll_write(cx, buf) }, - State::Expecting { io, .. } => io.write(buf), - State::Invalid => panic!("Negotiated: Invalid state") + State::Expecting { io, .. } => io.poll_write(cx, buf), + State::Invalid => panic!("Negotiated: Invalid state"), } } - fn flush(&mut self) -> io::Result<()> { - match &mut self.state { - State::Completed { io, ref mut remaining } => { + #[project] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + #[project] + match self.project().state.project() { + State::Completed { mut io, remaining } => { while !remaining.is_empty() { - let n = io.write(remaining)?; + let n = ready!(io.as_mut().poll_write(cx, &remaining)?); if n == 0 { - return Err(io::Error::new( - io::ErrorKind::WriteZero, - "Failed to write remaining buffer.")) + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) } remaining.advance(n); } - io.flush() + io.poll_flush(cx) }, - State::Expecting { io, .. } => io.flush(), - State::Invalid => panic!("Negotiated: Invalid state") + State::Expecting { io, .. } => io.poll_flush(cx), + State::Invalid => panic!("Negotiated: Invalid state"), } } -} -impl AsyncWrite for Negotiated -where - TInner: AsyncWrite + AsyncRead -{ - fn shutdown(&mut self) -> Poll<(), io::Error> { + #[project] + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { // Ensure all data has been flushed and expected negotiation messages // have been received. - try_ready!(self.poll().map_err(Into::::into)); + ready!(self.as_mut().poll(cx).map_err(Into::::into)?); + ready!(self.as_mut().poll_flush(cx).map_err(Into::::into)?); + // Continue with the shutdown of the underlying I/O stream. - match &mut self.state { - State::Completed { io, .. } => io.shutdown(), - State::Expecting { io, .. } => io.shutdown(), - State::Invalid => panic!("Negotiated: Invalid state") + #[project] + match self.project().state.project() { + State::Completed { io, .. } => io.poll_close(cx), + State::Expecting { io, .. } => io.poll_close(cx), + State::Invalid => panic!("Negotiated: Invalid state"), + } + } + + #[project] + fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice]) + -> Poll> + { + #[project] + match self.project().state.project() { + State::Completed { mut io, remaining } => { + while !remaining.is_empty() { + let n = ready!(io.as_mut().poll_write(cx, &remaining)?); + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) + } + remaining.advance(n); + } + io.poll_write_vectored(cx, bufs) + }, + State::Expecting { io, .. } => io.poll_write_vectored(cx, bufs), + State::Invalid => panic!("Negotiated: Invalid state"), } } } @@ -300,12 +359,12 @@ impl From for NegotiationError { } } -impl Into for NegotiationError { - fn into(self) -> io::Error { - if let NegotiationError::ProtocolError(e) = self { +impl From for io::Error { + fn from(err: NegotiationError) -> io::Error { + if let NegotiationError::ProtocolError(e) = err { return e.into() } - io::Error::new(io::ErrorKind::Other, self) + io::Error::new(io::ErrorKind::Other, err) } } @@ -333,27 +392,33 @@ impl fmt::Display for NegotiationError { mod tests { use super::*; use quickcheck::*; - use std::io::Write; + use std::{io::Write, task::Poll}; /// An I/O resource with a fixed write capacity (total and per write op). struct Capped { buf: Vec, step: usize } - impl io::Write for Capped { - fn write(&mut self, buf: &[u8]) -> io::Result { + impl AsyncRead for Capped { + fn poll_read(self: Pin<&mut Self>, _: &mut Context, _: &mut [u8]) -> Poll> { + unreachable!() + } + } + + impl AsyncWrite for Capped { + fn poll_write(mut self: Pin<&mut Self>, _: &mut Context, buf: &[u8]) -> Poll> { if self.buf.len() + buf.len() > self.buf.capacity() { - return Err(io::ErrorKind::WriteZero.into()) + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) } - self.buf.write(&buf[.. usize::min(self.step, buf.len())]) + let len = usize::min(self.step, buf.len()); + let n = Write::write(&mut self.buf, &buf[.. len]).unwrap(); + Poll::Ready(Ok(n)) } - fn flush(&mut self) -> io::Result<()> { - Ok(()) + fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } - } - impl AsyncWrite for Capped { - fn shutdown(&mut self) -> Poll<(), io::Error> { - Ok(().into()) + fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } } @@ -369,7 +434,7 @@ mod tests { loop { // Write until `new` has been fully written or the capped buffer runs // over capacity and yields WriteZero. - match io.write(&new[written..]) { + match future::poll_fn(|cx| Pin::new(&mut io).poll_write(cx, &new[written..])).now_or_never().unwrap() { Ok(n) => if let State::Completed { remaining, .. } = &io.state { assert!(remaining.is_empty()); @@ -388,7 +453,7 @@ mod tests { return TestResult::failed() } } - Err(e) => panic!("Unexpected error: {:?}", e) + Err(e) => panic!("Unexpected error: {:?}", e), } } } diff --git a/misc/multistream-select/src/protocol.rs b/misc/multistream-select/src/protocol.rs index 55e2943a6ef..1c184a31bf7 100644 --- a/misc/multistream-select/src/protocol.rs +++ b/misc/multistream-select/src/protocol.rs @@ -25,12 +25,11 @@ //! `Stream` and `Sink` implementations of `MessageIO` and //! `MessageReader`. -use bytes::{Bytes, BytesMut, BufMut}; use crate::length_delimited::{LengthDelimited, LengthDelimitedReader}; -use futures::{prelude::*, try_ready}; -use log::trace; -use std::{io, fmt, error::Error, convert::TryFrom}; -use tokio_io::{AsyncRead, AsyncWrite}; + +use bytes::{Bytes, BytesMut, BufMut}; +use futures::{prelude::*, io::IoSlice, ready}; +use std::{convert::TryFrom, io, fmt, error::Error, pin::Pin, task::{Context, Poll}}; use unsigned_varint as uvi; /// The maximum number of supported protocols that can be processed. @@ -264,7 +263,9 @@ impl Message { } /// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s. +#[pin_project::pin_project] pub struct MessageIO { + #[pin] inner: LengthDelimited, } @@ -277,8 +278,8 @@ impl MessageIO { Self { inner: LengthDelimited::new(inner) } } - /// Converts the `MessageIO` into a `MessageReader`, dropping the - /// `Message`-oriented `Sink` in favour of direct `AsyncWrite` access + /// Converts the [`MessageIO`] into a [`MessageReader`], dropping the + /// [`Message`]-oriented `Sink` in favour of direct `AsyncWrite` access /// to the underlying I/O stream. /// /// This is typically done if further negotiation messages are expected to be @@ -288,7 +289,7 @@ impl MessageIO { MessageReader { inner: self.inner.into_reader() } } - /// Drops the `MessageIO` resource, yielding the underlying I/O stream + /// Drops the [`MessageIO`] resource, yielding the underlying I/O stream /// together with the remaining write buffer containing the protocol /// negotiation frame data that has not yet been written to the I/O stream. /// @@ -309,28 +310,28 @@ impl MessageIO { } } -impl Sink for MessageIO +impl Sink for MessageIO where R: AsyncWrite, { - type SinkItem = Message; - type SinkError = ProtocolError; + type Error = ProtocolError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().inner.poll_ready(cx).map_err(From::from) + } - fn start_send(&mut self, msg: Self::SinkItem) -> StartSend { + fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { let mut buf = BytesMut::new(); - msg.encode(&mut buf)?; - match self.inner.start_send(buf.freeze())? { - AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(msg)), - AsyncSink::Ready => Ok(AsyncSink::Ready), - } + item.encode(&mut buf)?; + self.project().inner.start_send(buf.freeze()).map_err(From::from) } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - Ok(self.inner.poll_complete()?) + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().inner.poll_flush(cx).map_err(From::from) } - fn close(&mut self) -> Poll<(), Self::SinkError> { - Ok(self.inner.close()?) + fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().inner.poll_close(cx).map_err(From::from) } } @@ -338,18 +339,24 @@ impl Stream for MessageIO where R: AsyncRead { - type Item = Message; - type Error = ProtocolError; - - fn poll(&mut self) -> Poll, Self::Error> { - poll_stream(&mut self.inner) + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match poll_stream(self.project().inner, cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(Ok(m))), + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(From::from(err)))), + } } } /// A `MessageReader` implements a `Stream` of `Message`s on an underlying /// I/O resource combined with direct `AsyncWrite` access. +#[pin_project::pin_project] #[derive(Debug)] pub struct MessageReader { + #[pin] inner: LengthDelimitedReader } @@ -373,60 +380,56 @@ impl MessageReader { pub fn into_inner(self) -> (R, BytesMut) { self.inner.into_inner() } - - /// Returns a reference to the underlying I/O stream. - pub fn inner_ref(&self) -> &R { - self.inner.inner_ref() - } } impl Stream for MessageReader where R: AsyncRead { - type Item = Message; - type Error = ProtocolError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - poll_stream(&mut self.inner) + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + poll_stream(self.project().inner, cx) } } -impl io::Write for MessageReader +impl AsyncWrite for MessageReader where - R: AsyncWrite + TInner: AsyncWrite { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.write(buf) + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + self.project().inner.poll_write(cx, buf) } - fn flush(&mut self) -> io::Result<()> { - self.inner.flush() + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().inner.poll_flush(cx) } -} -impl AsyncWrite for MessageReader -where - TInner: AsyncWrite -{ - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.inner.shutdown() + fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().inner.poll_close(cx) + } + + fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice]) -> Poll> { + self.project().inner.poll_write_vectored(cx, bufs) } } -fn poll_stream(stream: &mut S) -> Poll, ProtocolError> +fn poll_stream(stream: Pin<&mut S>, cx: &mut Context) -> Poll>> where - S: Stream, + S: Stream>, { - let msg = if let Some(msg) = try_ready!(stream.poll()) { - Message::decode(msg)? + let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) { + match Message::decode(msg) { + Ok(m) => m, + Err(err) => return Poll::Ready(Some(Err(err))), + } } else { - return Ok(Async::Ready(None)) + return Poll::Ready(None) }; - trace!("Received message: {:?}", msg); + log::trace!("Received message: {:?}", msg); - Ok(Async::Ready(Some(msg))) + Poll::Ready(Some(Ok(msg))) } /// A protocol error. diff --git a/misc/multistream-select/src/tests.rs b/misc/multistream-select/src/tests.rs index 0f2a33abd00..c5ddd43ea3a 100644 --- a/misc/multistream-select/src/tests.rs +++ b/misc/multistream-select/src/tests.rs @@ -25,164 +25,156 @@ use crate::{Version, NegotiationError}; use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial}; use crate::{dialer_select_proto, listener_select_proto}; + +use async_std::net::{TcpListener, TcpStream}; use futures::prelude::*; -use tokio::runtime::current_thread::Runtime; -use tokio_tcp::{TcpListener, TcpStream}; -use tokio_io::io as nio; #[test] fn select_proto_basic() { - fn run(version: Version) { - let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + async fn run(version: Version) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let listener_addr = listener.local_addr().unwrap(); - let server = listener - .incoming() - .into_future() - .map(|s| s.0.unwrap()) - .map_err(|(e, _)| e.into()) - .and_then(move |connec| { - let protos = vec![b"/proto1", b"/proto2"]; - listener_select_proto(connec, protos) - }) - .and_then(|(proto, io)| { - nio::write_all(io, b"pong").from_err().map(move |_| proto) - }); - - let client = TcpStream::connect(&listener_addr) - .from_err() - .and_then(move |connec| { - let protos = vec![b"/proto3", b"/proto2"]; - dialer_select_proto(connec, protos, version) - }) - .and_then(|(proto, io)| { - nio::write_all(io, b"ping").from_err().map(move |(io, _)| (proto, io)) - }) - .and_then(|(proto, io)| { - nio::read_exact(io, [0; 4]).from_err().map(move |(_, msg)| { - assert_eq!(&msg, b"pong"); - proto - }) - }); - - let mut rt = Runtime::new().unwrap(); - let (dialer_chosen, listener_chosen) = - rt.block_on(client.join(server)).unwrap(); - - assert_eq!(dialer_chosen, b"/proto2"); - assert_eq!(listener_chosen, b"/proto2"); + let server = async_std::task::spawn(async move { + let connec = listener.accept().await.unwrap().0; + let protos = vec![b"/proto1", b"/proto2"]; + let (proto, mut io) = listener_select_proto(connec, protos).await.unwrap(); + assert_eq!(proto, b"/proto2"); + + let mut out = vec![0; 32]; + let n = io.read(&mut out).await.unwrap(); + out.truncate(n); + assert_eq!(out, b"ping"); + + io.write_all(b"pong").await.unwrap(); + io.flush().await.unwrap(); + }); + + let client = async_std::task::spawn(async move { + let connec = TcpStream::connect(&listener_addr).await.unwrap(); + let protos = vec![b"/proto3", b"/proto2"]; + let (proto, mut io) = dialer_select_proto(connec, protos.into_iter(), version) + .await.unwrap(); + assert_eq!(proto, b"/proto2"); + + io.write_all(b"ping").await.unwrap(); + io.flush().await.unwrap(); + + let mut out = vec![0; 32]; + let n = io.read(&mut out).await.unwrap(); + out.truncate(n); + assert_eq!(out, b"pong"); + }); + + server.await; + client.await; } - run(Version::V1); - run(Version::V1Lazy); + async_std::task::block_on(run(Version::V1)); + async_std::task::block_on(run(Version::V1Lazy)); } #[test] fn no_protocol_found() { - fn run(version: Version) { - let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + async fn run(version: Version) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let listener_addr = listener.local_addr().unwrap(); - let server = listener - .incoming() - .into_future() - .map(|s| s.0.unwrap()) - .map_err(|(e, _)| e.into()) - .and_then(move |connec| { - let protos = vec![b"/proto1", b"/proto2"]; - listener_select_proto(connec, protos) - }) - .and_then(|(proto, io)| io.complete().map(move |_| proto)); - - let client = TcpStream::connect(&listener_addr) - .from_err() - .and_then(move |connec| { - let protos = vec![b"/proto3", b"/proto4"]; - dialer_select_proto(connec, protos, version) - }) - .and_then(|(proto, io)| io.complete().map(move |_| proto)); - - let mut rt = Runtime::new().unwrap(); - match rt.block_on(client.join(server)) { - Err(NegotiationError::Failed) => (), - e => panic!("{:?}", e), - } + let server = async_std::task::spawn(async move { + let connec = listener.accept().await.unwrap().0; + let protos = vec![b"/proto1", b"/proto2"]; + let io = match listener_select_proto(connec, protos).await { + Ok((_, io)) => io, + // We don't explicitly check for `Failed` because the client might close the connection when it + // realizes that we have no protocol in common. + Err(_) => return, + }; + match io.complete().await { + Err(NegotiationError::Failed) => {}, + _ => panic!(), + } + }); + + let client = async_std::task::spawn(async move { + let connec = TcpStream::connect(&listener_addr).await.unwrap(); + let protos = vec![b"/proto3", b"/proto4"]; + let io = match dialer_select_proto(connec, protos.into_iter(), version).await { + Err(NegotiationError::Failed) => return, + Ok((_, io)) => io, + Err(_) => panic!() + }; + match io.complete().await { + Err(NegotiationError::Failed) => {}, + _ => panic!(), + } + }); + + server.await; + client.await; } - run(Version::V1); - run(Version::V1Lazy); + async_std::task::block_on(run(Version::V1)); + async_std::task::block_on(run(Version::V1Lazy)); } #[test] fn select_proto_parallel() { - fn run(version: Version) { - let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + async fn run(version: Version) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let listener_addr = listener.local_addr().unwrap(); - let server = listener - .incoming() - .into_future() - .map(|s| s.0.unwrap()) - .map_err(|(e, _)| e.into()) - .and_then(move |connec| { - let protos = vec![b"/proto1", b"/proto2"]; - listener_select_proto(connec, protos) - }) - .and_then(|(proto, io)| io.complete().map(move |_| proto)); - - let client = TcpStream::connect(&listener_addr) - .from_err() - .and_then(move |connec| { - let protos = vec![b"/proto3", b"/proto2"]; - dialer_select_proto_parallel(connec, protos.into_iter(), version) - }) - .and_then(|(proto, io)| io.complete().map(move |_| proto)); - - let mut rt = Runtime::new().unwrap(); - let (dialer_chosen, listener_chosen) = - rt.block_on(client.join(server)).unwrap(); - - assert_eq!(dialer_chosen, b"/proto2"); - assert_eq!(listener_chosen, b"/proto2"); + let server = async_std::task::spawn(async move { + let connec = listener.accept().await.unwrap().0; + let protos = vec![b"/proto1", b"/proto2"]; + let (proto, io) = listener_select_proto(connec, protos).await.unwrap(); + assert_eq!(proto, b"/proto2"); + io.complete().await.unwrap(); + }); + + let client = async_std::task::spawn(async move { + let connec = TcpStream::connect(&listener_addr).await.unwrap(); + let protos = vec![b"/proto3", b"/proto2"]; + let (proto, io) = dialer_select_proto_parallel(connec, protos.into_iter(), version) + .await.unwrap(); + assert_eq!(proto, b"/proto2"); + io.complete().await.unwrap(); + }); + + server.await; + client.await; } - run(Version::V1); - run(Version::V1Lazy); + async_std::task::block_on(run(Version::V1)); + async_std::task::block_on(run(Version::V1Lazy)); } #[test] fn select_proto_serial() { - fn run(version: Version) { - let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + async fn run(version: Version) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let listener_addr = listener.local_addr().unwrap(); - let server = listener - .incoming() - .into_future() - .map(|s| s.0.unwrap()) - .map_err(|(e, _)| e.into()) - .and_then(move |connec| { - let protos = vec![b"/proto1", b"/proto2"]; - listener_select_proto(connec, protos) - }) - .and_then(|(proto, io)| io.complete().map(move |_| proto)); - - let client = TcpStream::connect(&listener_addr) - .from_err() - .and_then(move |connec| { - let protos = vec![b"/proto3", b"/proto2"]; - dialer_select_proto_serial(connec, protos.into_iter(), version) - }) - .and_then(|(proto, io)| io.complete().map(move |_| proto)); - - let mut rt = Runtime::new().unwrap(); - let (dialer_chosen, listener_chosen) = - rt.block_on(client.join(server)).unwrap(); - - assert_eq!(dialer_chosen, b"/proto2"); - assert_eq!(listener_chosen, b"/proto2"); + let server = async_std::task::spawn(async move { + let connec = listener.accept().await.unwrap().0; + let protos = vec![b"/proto1", b"/proto2"]; + let (proto, io) = listener_select_proto(connec, protos).await.unwrap(); + assert_eq!(proto, b"/proto2"); + io.complete().await.unwrap(); + }); + + let client = async_std::task::spawn(async move { + let connec = TcpStream::connect(&listener_addr).await.unwrap(); + let protos = vec![b"/proto3", b"/proto2"]; + let (proto, io) = dialer_select_proto_serial(connec, protos.into_iter(), version) + .await.unwrap(); + assert_eq!(proto, b"/proto2"); + io.complete().await.unwrap(); + }); + + server.await; + client.await; } - run(Version::V1); - run(Version::V1Lazy); + async_std::task::block_on(run(Version::V1)); + async_std::task::block_on(run(Version::V1Lazy)); } From 24fe550e591d2c38b43386bb7a41e53c40867401 Mon Sep 17 00:00:00 2001 From: Pierre Krieger Date: Thu, 5 Mar 2020 18:20:41 +0100 Subject: [PATCH 2/2] Fix intradoc links --- misc/multistream-select/src/length_delimited.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/misc/multistream-select/src/length_delimited.rs b/misc/multistream-select/src/length_delimited.rs index 42e82a85040..f75d9703b0e 100644 --- a/misc/multistream-select/src/length_delimited.rs +++ b/misc/multistream-select/src/length_delimited.rs @@ -299,10 +299,10 @@ impl LengthDelimitedReader { /// # Panic /// /// Will panic if called while there is data in the read or write buffer. - /// The read buffer is guaranteed to be empty whenever [`Stream::poll`] yields - /// a new `Message`. The write buffer is guaranteed to be empty whenever - /// [`LengthDelimited::poll_write_buffer`] yields [`Async::Ready`] or after - /// the [`Sink`] has been completely flushed via [`Sink::poll_complete`]. + /// The read buffer is guaranteed to be empty whenever [`Stream::poll_next`] + /// yield a new `Message`. The write buffer is guaranteed to be empty whenever + /// [`LengthDelimited::poll_write_buffer`] yields [`Poll::Ready`] or after + /// the [`Sink`] has been completely flushed via [`Sink::poll_flush`]. pub fn into_inner(self) -> (R, BytesMut) { self.inner.into_inner() }