diff --git a/Cargo.lock b/Cargo.lock index 6358fca2ea..04c9d7f378 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7798,14 +7798,16 @@ dependencies = [ [[package]] name = "yamux" -version = "0.10.2" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5d9ba232399af1783a58d8eb26f6b5006fbefe2dc9ef36bd283324792d03ea5" +checksum = "5f97202f6b125031b95d83e01dc57292b529384f80bfae4677e4bbc10178cf72" dependencies = [ "futures 0.3.29", + "instant", "log", "nohash-hasher", "parking_lot 0.12.1", + "pin-project 1.1.3", "rand", "static_assertions", ] diff --git a/comms/core/Cargo.toml b/comms/core/Cargo.toml index 14e6a0aa47..fba684ac4c 100644 --- a/comms/core/Cargo.toml +++ b/comms/core/Cargo.toml @@ -12,8 +12,8 @@ edition = "2018" [dependencies] tari_crypto = { version = "0.20" } tari_metrics = { path = "../../infrastructure/metrics", optional = true, version = "1.0.0-pre.12" } -tari_storage = { path = "../../infrastructure/storage", version = "1.0.0-pre.12" } -tari_shutdown = { path = "../../infrastructure/shutdown" , version = "1.0.0-pre.12"} +tari_storage = { path = "../../infrastructure/storage", version = "1.0.0-pre.12" } +tari_shutdown = { path = "../../infrastructure/shutdown", version = "1.0.0-pre.12" } tari_utilities = { version = "0.7" } anyhow = "1.0.53" @@ -44,13 +44,13 @@ thiserror = "1.0.26" tokio = { version = "1.36", features = ["rt-multi-thread", "time", "sync", "signal", "net", "macros", "io-util"] } tokio-stream = { version = "0.1.9", features = ["sync"] } tokio-util = { version = "0.6.7", features = ["codec", "compat"] } -tower = {version = "0.4", features = ["util"]} +tower = { version = "0.4", features = ["util"] } tracing = "0.1.26" -yamux = "=0.10.2" +yamux = "0.13.2" zeroize = "1" [dev-dependencies] -tari_test_utils = { path = "../../infrastructure/test_utils" } +tari_test_utils = { path = "../../infrastructure/test_utils" } tari_comms_rpc_macros = { path = "../rpc_macros" } env_logger = "0.7.0" @@ -58,7 +58,7 @@ serde_json = "1.0.39" tempfile = "3.1.0" [build-dependencies] -tari_common = { path = "../../common", features = ["build"], version = "1.0.0-pre.12" } +tari_common = { path = "../../common", features = ["build"], version = "1.0.0-pre.12" } [features] c_integration = [] diff --git a/comms/core/src/connection_manager/listener.rs b/comms/core/src/connection_manager/listener.rs index 0631692101..937b1d9f8d 100644 --- a/comms/core/src/connection_manager/listener.rs +++ b/comms/core/src/connection_manager/listener.rs @@ -89,7 +89,7 @@ pub struct PeerListener { impl PeerListener where TTransport: Transport + Clone + Send + Sync + 'static, - TTransport::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static, + TTransport::Output: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { pub fn new( config: ConnectionManagerConfig, diff --git a/comms/core/src/multiplexing/yamux.rs b/comms/core/src/multiplexing/yamux.rs index 5e16dfc459..0730809b5a 100644 --- a/comms/core/src/multiplexing/yamux.rs +++ b/comms/core/src/multiplexing/yamux.rs @@ -20,15 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::{future::Future, io, pin::Pin, task::Poll}; +use std::{future::poll_fn, io, marker::PhantomData, pin::Pin, task::Poll}; -use futures::{task::Context, Stream}; +use futures::{channel::oneshot, task::Context, Stream}; use tokio::{ io::{AsyncRead, AsyncWrite, ReadBuf}, sync::mpsc, }; use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; -use tracing::{self, debug, error}; +use tracing::{self, debug, error, warn}; // Reexport pub use yamux::ConnectionError; use yamux::Mode; @@ -48,30 +48,20 @@ pub struct Yamux { substream_counter: AtomicRefCounter, } -const MAX_BUFFER_SIZE: u32 = 8 * 1024 * 1024; // 8MiB -const RECEIVE_WINDOW: u32 = 5 * 1024 * 1024; // 5MiB - impl Yamux { /// Upgrade the underlying socket to use yamux pub fn upgrade_connection(socket: TSocket, direction: ConnectionDirection) -> io::Result - where TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static { + where TSocket: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static { let mode = match direction { ConnectionDirection::Inbound => Mode::Server, ConnectionDirection::Outbound => Mode::Client, }; - let mut config = yamux::Config::default(); - - config.set_window_update_mode(yamux::WindowUpdateMode::OnRead); - // Because OnRead mode increases the RTT of window update, bigger buffer size and receive - // window size perform better. - config.set_max_buffer_size(MAX_BUFFER_SIZE as usize); - config.set_receive_window(RECEIVE_WINDOW); + let config = yamux::Config::default(); let substream_counter = AtomicRefCounter::new(); let connection = yamux::Connection::new(socket.compat(), config, mode); - let control = Control::new(connection.control(), substream_counter.clone()); - let incoming = Self::spawn_incoming_stream_worker(connection, substream_counter.clone()); + let (control, incoming) = Self::spawn_incoming_stream_worker(connection, substream_counter.clone()); Ok(Self { control, @@ -85,14 +75,16 @@ impl Yamux { fn spawn_incoming_stream_worker( connection: yamux::Connection, counter: AtomicRefCounter, - ) -> IncomingSubstreams + ) -> (Control, IncomingSubstreams) where - TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static, + TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + Sync + 'static, { let (incoming_tx, incoming_rx) = mpsc::channel(10); - let incoming = IncomingWorker::new(connection, incoming_tx); - tokio::spawn(incoming.run()); - IncomingSubstreams::new(incoming_rx, counter) + let (request_tx, request_rx) = mpsc::channel(1); + let incoming = YamuxWorker::new(incoming_tx, request_rx, counter.clone()); + let control = Control::new(request_tx); + tokio::spawn(incoming.run(connection)); + (control, IncomingSubstreams::new(incoming_rx, counter)) } /// Get the yamux control struct @@ -121,42 +113,45 @@ impl Yamux { } } +#[derive(Debug)] +pub enum YamuxRequest { + OpenStream { + reply: oneshot::Sender>, + }, + Close { + reply: oneshot::Sender>, + }, +} + #[derive(Clone)] pub struct Control { - inner: yamux::Control, - substream_counter: AtomicRefCounter, + request_tx: mpsc::Sender, } impl Control { - pub fn new(inner: yamux::Control, substream_counter: AtomicRefCounter) -> Self { - Self { - inner, - substream_counter, - } + pub fn new(request_tx: mpsc::Sender) -> Self { + Self { request_tx } } /// Open a new stream to the remote. pub async fn open_stream(&mut self) -> Result { - // Ensure that this counts as used while the substream is being opened - let counter_guard = self.substream_counter.new_guard(); - let stream = self.inner.open_stream().await?; - Ok(Substream { - stream: stream.compat(), - _counter_guard: counter_guard, - }) + let (reply, reply_rx) = oneshot::channel(); + self.request_tx + .send(YamuxRequest::OpenStream { reply }) + .await + .map_err(|_| ConnectionError::Closed)?; + let stream = reply_rx.await.map_err(|_| ConnectionError::Closed)??; + Ok(stream) } /// Close the connection. - pub fn close(&mut self) -> impl Future> + '_ { - self.inner.close() - } - - pub fn substream_count(&self) -> usize { - self.substream_counter.get() - } - - pub(crate) fn substream_counter(&self) -> AtomicRefCounter { - self.substream_counter.clone() + pub async fn close(&mut self) -> Result<(), ConnectionError> { + let (reply, reply_rx) = oneshot::channel(); + self.request_tx + .send(YamuxRequest::Close { reply }) + .await + .map_err(|_| ConnectionError::Closed)?; + reply_rx.await.map_err(|_| ConnectionError::Closed)? } } @@ -240,52 +235,78 @@ impl From for stream_id::Id { } } -struct IncomingWorker { - connection: yamux::Connection, - sender: mpsc::Sender, +struct YamuxWorker { + incoming_substreams: mpsc::Sender, + request_rx: mpsc::Receiver, + counter: AtomicRefCounter, + _phantom: PhantomData, } -impl IncomingWorker -where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static /* */ +impl YamuxWorker +where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + Sync + 'static { - pub fn new(connection: yamux::Connection, sender: mpsc::Sender) -> Self { - Self { connection, sender } + pub fn new( + incoming_substreams: mpsc::Sender, + request_rx: mpsc::Receiver, + counter: AtomicRefCounter, + ) -> Self { + Self { + incoming_substreams, + request_rx, + counter, + _phantom: PhantomData, + } } - pub async fn run(mut self) { + async fn run(mut self, mut connection: yamux::Connection) { loop { tokio::select! { - _ = self.sender.closed() => { - self.close().await; + biased; + + _ = self.incoming_substreams.closed() => { + debug!( + target: LOG_TARGET, + "{} Incoming peer substream task is stopping because the internal stream sender channel was \ + closed", + self.counter.get() + ); + // Ignore: we already log the error variant in Self::close + let _ignore = Self::close(&mut connection).await; break }, - result = self.connection.next_stream() => { + Some(request) = self.request_rx.recv() => { + if let Err(err) = self.handle_request(&mut connection, request).await { + error!(target: LOG_TARGET, "Error handling request: {err}"); + break; + } + }, + + result = Self::next_inbound_stream(&mut connection) => { match result { - Ok(Some(stream)) => { - if self.sender.send(stream).await.is_err() { + Some(Ok(stream)) => { + if self.incoming_substreams.send(stream).await.is_err() { debug!( target: LOG_TARGET, - "{} Incoming peer substream task is stopping because the internal stream sender channel \ - was closed", - self.connection + "{} Incoming peer substream task is stopping because the internal stream sender channel was closed", + self.counter.get() ); break; } }, - Ok(None) =>{ + None =>{ debug!( target: LOG_TARGET, "{} Incoming peer substream ended.", - self.connection + self.counter.get() ); break; } - Err(err) => { + Some(Err(err)) => { error!( target: LOG_TARGET, "{} Incoming peer substream task received an error because '{}'", - self.connection, + self.counter.get(), err ); break; @@ -296,38 +317,46 @@ where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static } } - async fn close(&mut self) { - let mut control = self.connection.control(); - // Sends the close message once polled, while continuing to poll the connection future - let close_fut = control.close(); - tokio::pin!(close_fut); - loop { - tokio::select! { - biased; + async fn handle_request( + &self, + connection_mut: &mut yamux::Connection, + request: YamuxRequest, + ) -> io::Result<()> { + match request { + YamuxRequest::OpenStream { reply } => { + let result = poll_fn(move |cx| connection_mut.poll_new_outbound(cx)).await; + if reply + .send(result.map(|stream| Substream { + stream: stream.compat(), + _counter_guard: self.counter.new_guard(), + })) + .is_err() + { + warn!(target: LOG_TARGET, "Request to open substream was aborted before reply was sent"); + } + }, + YamuxRequest::Close { reply } => { + if reply.send(Self::close(connection_mut).await).is_err() { + warn!(target: LOG_TARGET, "Request to close substream was aborted before reply was sent"); + } + }, + } + Ok(()) + } - result = &mut close_fut => { - match result { - Ok(_) => break, - Err(err) => { - error!(target: LOG_TARGET, "Failed to close yamux connection: {}", err); - break; - } - } - }, + async fn next_inbound_stream( + connection_mut: &mut yamux::Connection, + ) -> Option> { + poll_fn(|cx| connection_mut.poll_next_inbound(cx)).await + } - result = self.connection.next_stream() => { - match result { - Ok(Some(_)) => continue, - Ok(None) => break, - Err(err) => { - error!(target: LOG_TARGET, "Error while closing yamux connection: {}", err); - continue; - } - } - } - } + async fn close(connection: &mut yamux::Connection) -> yamux::Result<()> { + if let Err(err) = poll_fn(|cx| connection.poll_close(cx)).await { + error!(target: LOG_TARGET, "Error while closing yamux connection: {}", err); + return Err(err); } - debug!(target: LOG_TARGET, "{} Yamux connection has closed", self.connection); + debug!(target: LOG_TARGET, "Yamux connection has closed"); + Ok(()) } } @@ -356,21 +385,18 @@ mod test { let mut substream = dialer_control.open_stream().await.unwrap(); substream.write_all(msg).await.unwrap(); - substream.flush().await.unwrap(); substream.shutdown().await.unwrap(); }); - let mut listener = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound)?.into_incoming(); + let mut listener = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound)?; let mut substream = listener + .incoming .next() .await .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no substream"))?; let mut buf = Vec::new(); - tokio::select! { - _ = substream.read_to_end(&mut buf) => {}, - _ = listener.next() => {}, - }; + substream.read_to_end(&mut buf).await?; assert_eq!(buf, msg); Ok(()) @@ -387,15 +413,21 @@ mod test { let substreams_out = tokio::spawn(async move { let mut substreams = Vec::with_capacity(NUM_SUBSTREAMS); for _ in 0..NUM_SUBSTREAMS { - substreams.push(dialer_control.open_stream().await.unwrap()); + let mut stream = dialer_control.open_stream().await.unwrap(); + // Since Yamux 0.12.0 the client does not initiate a substream unless you actually write something + stream.write_all(b"hello").await.unwrap(); + substreams.push(stream); } substreams }); - let mut listener = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound) - .unwrap() - .into_incoming(); - let substreams_in = collect_stream!(&mut listener, take = NUM_SUBSTREAMS, timeout = Duration::from_secs(10)); + let mut listener = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound).unwrap(); + + let substreams_in = collect_stream!( + &mut listener.incoming, + take = NUM_SUBSTREAMS, + timeout = Duration::from_secs(10) + ); assert_eq!(dialer.substream_count(), NUM_SUBSTREAMS); assert_eq!(listener.substream_count(), NUM_SUBSTREAMS); @@ -426,8 +458,8 @@ mod test { assert_eq!(buf, b""); }); - let mut incoming = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound)?.into_incoming(); - let mut substream = incoming.next().await.unwrap(); + let mut listener = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound)?; + let mut substream = listener.incoming.next().await.unwrap(); let mut buf = vec![0; msg.len()]; substream.read_exact(&mut buf).await?; @@ -482,12 +514,13 @@ mod test { let (dialer, listener) = MemorySocket::new_pair(); let dialer = Yamux::upgrade_connection(dialer, ConnectionDirection::Outbound)?; + let substream_counter = dialer.substream_counter(); let mut dialer_control = dialer.get_yamux_control(); tokio::spawn(async move { - assert_eq!(dialer_control.substream_count(), 0); + assert_eq!(substream_counter.get(), 0); let mut substream = dialer_control.open_stream().await.unwrap(); - assert_eq!(dialer_control.substream_count(), 1); + assert_eq!(substream_counter.get(), 1); let msg = vec![0x55u8; MSG_LEN]; substream.write_all(msg.as_slice()).await.unwrap(); @@ -500,10 +533,10 @@ mod test { assert_eq!(buf, vec![0xAAu8; MSG_LEN]); }); - let mut incoming = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound)?.into_incoming(); - assert_eq!(incoming.substream_count(), 0); - let mut substream = incoming.next().await.unwrap(); - assert_eq!(incoming.substream_count(), 1); + let mut listener = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound)?; + assert_eq!(listener.substream_count(), 0); + let mut substream = listener.incoming.next().await.unwrap(); + assert_eq!(listener.substream_count(), 1); let mut buf = vec![0u8; MSG_LEN]; substream.read_exact(&mut buf).await?; @@ -514,7 +547,7 @@ mod test { substream.shutdown().await?; drop(substream); - assert_eq!(incoming.substream_count(), 0); + assert_eq!(listener.substream_count(), 0); Ok(()) } diff --git a/comms/core/src/protocol/messaging/inbound.rs b/comms/core/src/protocol/messaging/inbound.rs index d0a1eeaa31..895ba1db32 100644 --- a/comms/core/src/protocol/messaging/inbound.rs +++ b/comms/core/src/protocol/messaging/inbound.rs @@ -22,8 +22,9 @@ use std::io; -use futures::StreamExt; +use futures::{future::Either, SinkExt, StreamExt}; use log::*; +use tari_shutdown::ShutdownSignal; use tokio::{ io::{AsyncRead, AsyncWrite}, sync::{broadcast, mpsc}, @@ -32,7 +33,7 @@ use tokio::{ #[cfg(feature = "metrics")] use super::metrics; use super::{MessagingEvent, MessagingProtocol}; -use crate::{message::InboundMessage, peer_manager::NodeId}; +use crate::{message::InboundMessage, peer_manager::NodeId, protocol::rpc::__macro_reexports::future}; const LOG_TARGET: &str = "comms::protocol::messaging::inbound"; @@ -42,6 +43,7 @@ pub struct InboundMessaging { inbound_message_tx: mpsc::Sender, messaging_events_tx: broadcast::Sender, enable_message_received_event: bool, + shutdown_signal: ShutdownSignal, } impl InboundMessaging { @@ -50,16 +52,18 @@ impl InboundMessaging { inbound_message_tx: mpsc::Sender, messaging_events_tx: broadcast::Sender, enable_message_received_event: bool, + shutdown_signal: ShutdownSignal, ) -> Self { Self { peer, inbound_message_tx, messaging_events_tx, enable_message_received_event, + shutdown_signal, } } - pub async fn run(self, socket: S) + pub async fn run(mut self, socket: S) where S: AsyncRead + AsyncWrite + Unpin { let peer = &self.peer; #[cfg(feature = "metrics")] @@ -71,10 +75,9 @@ impl InboundMessaging { ); let stream = MessagingProtocol::framed(socket); - tokio::pin!(stream); - while let Some(result) = stream.next().await { + while let Either::Right((Some(result), _)) = future::select(self.shutdown_signal.wait(), stream.next()).await { match result { Ok(raw_msg) => { #[cfg(feature = "metrics")] @@ -138,6 +141,8 @@ impl InboundMessaging { } } + let _ignore = stream.close().await; + let _ignore = self .messaging_events_tx .send(MessagingEvent::InboundProtocolExited(peer.clone())); diff --git a/comms/core/src/protocol/messaging/protocol.rs b/comms/core/src/protocol/messaging/protocol.rs index ac7f69ab70..0fc6a7d5a9 100644 --- a/comms/core/src/protocol/messaging/protocol.rs +++ b/comms/core/src/protocol/messaging/protocol.rs @@ -351,6 +351,7 @@ impl MessagingProtocol { inbound_message_tx, messaging_events_tx, self.enable_message_received_event, + self.shutdown_signal.clone(), ); let handle = tokio::spawn(inbound_messaging.run(substream)); self.active_inbound.insert(peer, handle); diff --git a/comms/core/src/protocol/messaging/test.rs b/comms/core/src/protocol/messaging/test.rs index 248d33a4d2..2a27e9e7ca 100644 --- a/comms/core/src/protocol/messaging/test.rs +++ b/comms/core/src/protocol/messaging/test.rs @@ -55,6 +55,7 @@ use crate::{ }; static TEST_MSG1: Bytes = Bytes::from_static(b"TEST_MSG1"); +static TEST_MSG2: Bytes = Bytes::from_static(b"TEST_MSG2"); static MESSAGING_PROTOCOL_ID: ProtocolId = ProtocolId::from_static(b"test/msg"); @@ -128,21 +129,21 @@ async fn new_inbound_substream_handling() { // Create connected memory sockets - we use each end of the connection as if they exist on different nodes let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await; - // Notify the messaging protocol that a new substream has been established that wants to talk the messaging. let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap(); + + let mut framed_ours = MessagingProtocol::framed(stream_ours); + framed_ours.send(TEST_MSG1.clone()).await.unwrap(); + + // Notify the messaging protocol that a new substream has been established that wants to talk the messaging. + let stream_theirs = muxer_theirs.incoming_mut().next().await.unwrap(); proto_tx .send(ProtocolNotification::new( MESSAGING_PROTOCOL_ID.clone(), - ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_ours), + ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_theirs), )) .await .unwrap(); - let stream_theirs = muxer_theirs.incoming_mut().next().await.unwrap(); - let mut framed_theirs = MessagingProtocol::framed(stream_theirs); - - framed_theirs.send(TEST_MSG1.clone()).await.unwrap(); - let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv()) .await .unwrap() @@ -370,42 +371,55 @@ async fn new_inbound_substream_only_single_session_permitted() { // Create connected memory sockets - we use each end of the connection as if they exist on different nodes let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await; + // Spawn a task to deal with incoming substreams + tokio::spawn({ + let expected_node_id = expected_node_id.clone(); + async move { + while let Some(stream_theirs) = muxer_theirs.incoming_mut().next().await { + proto_tx + .send(ProtocolNotification::new( + MESSAGING_PROTOCOL_ID.clone(), + ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_theirs), + )) + .await + .unwrap(); + } + } + }); - // Notify the messaging protocol that a new substream has been established that wants to talk the messaging. + // Open first stream let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap(); - proto_tx - .send(ProtocolNotification::new( - MESSAGING_PROTOCOL_ID.clone(), - ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_ours), - )) - .await - .unwrap(); + let mut framed_ours = MessagingProtocol::framed(stream_ours); + framed_ours.send(TEST_MSG1.clone()).await.unwrap(); - // First stream is open - let stream_theirs = muxer_theirs.incoming_mut().next().await.unwrap(); - - // Open another one for messaging - let stream_ours2 = muxer_ours.get_yamux_control().open_stream().await.unwrap(); - proto_tx - .send(ProtocolNotification::new( - MESSAGING_PROTOCOL_ID.clone(), - ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_ours2), - )) + // Message comes through + let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv()) .await + .unwrap() .unwrap(); + assert_eq!(in_msg.source_peer, expected_node_id); + assert_eq!(in_msg.body, TEST_MSG1); // Check the second stream closes immediately - let stream_theirs2 = muxer_theirs.incoming_mut().next().await.unwrap(); - let mut framed_ours2 = MessagingProtocol::framed(stream_theirs2); - let next = framed_ours2.next().await; - // The stream is closed - assert!(next.is_none()); - - // The first stream is still active - let mut framed_theirs = MessagingProtocol::framed(stream_theirs); + let stream_ours2 = muxer_ours.get_yamux_control().open_stream().await.unwrap(); - framed_theirs.send(TEST_MSG1.clone()).await.unwrap(); + let mut framed_ours2 = MessagingProtocol::framed(stream_ours2); + // Check that it eventually exits. The first send will initiate the substream and send. Once the other side closes + // the connection it takes a few sends for that to be detected and the substream to be closed. + loop { + // This message will not go through + if let Err(e) = framed_ours2.send(TEST_MSG2.clone()).await { + assert_eq!( + e.to_string().split(':').nth(1).map(|s| s.trim()), + Some("connection is closed"), + "Expected connection to be closed but got '{e}'" + ); + break; + } + } + // First stream still open + framed_ours.send(TEST_MSG1.clone()).await.unwrap(); let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv()) .await .unwrap() @@ -414,23 +428,14 @@ async fn new_inbound_substream_only_single_session_permitted() { assert_eq!(in_msg.body, TEST_MSG1); // Close the first - framed_theirs.close().await.unwrap(); + framed_ours.close().await.unwrap(); // Open another one for messaging - let stream_ours2 = muxer_ours.get_yamux_control().open_stream().await.unwrap(); - proto_tx - .send(ProtocolNotification::new( - MESSAGING_PROTOCOL_ID.clone(), - ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_ours2), - )) - .await - .unwrap(); - - let stream_theirs = muxer_theirs.incoming_mut().next().await.unwrap(); - let mut framed_theirs = MessagingProtocol::framed(stream_theirs); - framed_theirs.send(TEST_MSG1.clone()).await.unwrap(); + let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap(); + let mut framed_ours = MessagingProtocol::framed(stream_ours); + framed_ours.send(TEST_MSG1.clone()).await.unwrap(); - // The second message comes through + // The third message comes through let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv()) .await .unwrap() diff --git a/comms/core/src/protocol/rpc/test/smoke.rs b/comms/core/src/protocol/rpc/test/smoke.rs index f2c17e1e59..344d29d2e1 100644 --- a/comms/core/src/protocol/rpc/test/smoke.rs +++ b/comms/core/src/protocol/rpc/test/smoke.rs @@ -27,20 +27,24 @@ use tari_shutdown::Shutdown; use tari_test_utils::unpack_enum; use tari_utilities::hex::Hex; use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, sync::{mpsc, RwLock}, task, time, }; +use tokio_stream::Stream; use crate::{ framing, - multiplexing::Yamux, + multiplexing::{Control, Yamux}, + peer_manager::NodeId, protocol::{ rpc, rpc::{ context::RpcCommsBackend, error::HandshakeRejectReason, handshake::RpcHandshakeError, + server::NamedProtocolService, test::{ greeting_service::{ GreetingClient, @@ -114,32 +118,46 @@ pub(super) async fn setup_service( setup_service_with_builder(service_impl, builder).await } +fn spawn_inbound( + mut inbound: impl Stream + Unpin + Send + 'static, + notif_tx: mpsc::Sender>, + node_id: NodeId, +) -> task::JoinHandle<()> { + task::spawn(async move { + while let Some(stream) = inbound.next().await { + notif_tx + .send(ProtocolNotification::new( + ProtocolId::from_static(GreetingClient::PROTOCOL_NAME), + ProtocolEvent::NewInboundSubstream(node_id.clone(), stream), + )) + .await + .unwrap(); + } + }) +} + pub(super) async fn setup( service_impl: T, num_concurrent_sessions: usize, -) -> (Yamux, Yamux, task::JoinHandle<()>, Arc, Shutdown) { +) -> (Control, Yamux, task::JoinHandle<()>, Arc, Shutdown) { let (notif_tx, server_hnd, context, shutdown) = setup_service(service_impl, num_concurrent_sessions).await; let (_, inbound, outbound) = build_multiplexed_connections().await; - let substream = outbound.get_yamux_control().open_stream().await.unwrap(); + let inbound_control = inbound.get_yamux_control(); let node_identity = build_node_identity(Default::default()); + let node_id = node_identity.node_id().clone(); + spawn_inbound(inbound.into_incoming(), notif_tx.clone(), node_id); + // Notify that a peer wants to speak the greeting RPC protocol context.peer_manager().add_peer(node_identity.to_peer()).await.unwrap(); - notif_tx - .send(ProtocolNotification::new( - ProtocolId::from_static(b"/test/greeting/1.0"), - ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream), - )) - .await - .unwrap(); - (inbound, outbound, server_hnd, node_identity, shutdown) + (inbound_control, outbound, server_hnd, node_identity, shutdown) } #[tokio::test] async fn request_response_errors_and_streaming() { - let (mut muxer, _outbound, server_hnd, node_identity, mut shutdown) = setup(GreetingService::default(), 1).await; - let socket = muxer.incoming_mut().next().await.unwrap(); + let (_inbound, outbound, server_hnd, node_identity, mut shutdown) = setup(GreetingService::default(), 1).await; + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, 1024); let mut client = GreetingClient::builder() @@ -221,8 +239,8 @@ async fn request_response_errors_and_streaming() { #[tokio::test] async fn concurrent_requests() { - let (mut muxer, _outbound, _, _, _shutdown) = setup(GreetingService::default(), 1).await; - let socket = muxer.incoming_mut().next().await.unwrap(); + let (_inbound, outbound, _, _, _shutdown) = setup(GreetingService::default(), 1).await; + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, 1024); let mut client = GreetingClient::builder() @@ -261,8 +279,8 @@ async fn concurrent_requests() { #[tokio::test] async fn response_too_big() { - let (mut muxer, _outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; - let socket = muxer.incoming_mut().next().await.unwrap(); + let (_inbound, outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, rpc::max_response_size()); let mut client = GreetingClient::builder() @@ -288,8 +306,8 @@ async fn response_too_big() { #[tokio::test] async fn ping_latency() { - let (mut muxer, _outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; - let socket = muxer.incoming_mut().next().await.unwrap(); + let (_inbound, outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, 1024); let mut client = GreetingClient::builder().connect(framed).await.unwrap(); @@ -302,8 +320,8 @@ async fn ping_latency() { #[tokio::test] async fn server_shutdown_before_connect() { - let (mut muxer, _outbound, _, _, mut shutdown) = setup(GreetingService::new(&[]), 1).await; - let socket = muxer.incoming_mut().next().await.unwrap(); + let (_inbound, outbound, _, _, mut shutdown) = setup(GreetingService::new(&[]), 1).await; + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, 1024); shutdown.trigger(); @@ -317,8 +335,8 @@ async fn server_shutdown_before_connect() { #[tokio::test] async fn timeout() { let delay = Arc::new(RwLock::new(Duration::from_secs(10))); - let (mut muxer, _outbound, _, _, _shutdown) = setup(SlowGreetingService::new(delay.clone()), 1).await; - let socket = muxer.incoming_mut().next().await.unwrap(); + let (_inbound, outbound, _, _, _shutdown) = setup(SlowGreetingService::new(delay.clone()), 1).await; + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, 1024); let mut client = GreetingClient::builder() .with_deadline(Duration::from_secs(1)) @@ -344,7 +362,9 @@ async fn unknown_protocol() { let (notif_tx, _, _, _shutdown) = setup_service(GreetingService::new(&[]), 1).await; let (_, inbound, mut outbound) = build_multiplexed_connections().await; - let in_substream = inbound.get_yamux_control().open_stream().await.unwrap(); + let mut in_substream = inbound.get_yamux_control().open_stream().await.unwrap(); + // To avoid having to spawn a inbound task, we can just write to the stream directly to initiate a substream + in_substream.write_all(b"hello").await.unwrap(); let node_identity = build_node_identity(Default::default()); @@ -359,7 +379,9 @@ async fn unknown_protocol() { .await .unwrap(); - let out_socket = outbound.incoming_mut().next().await.unwrap(); + let mut out_socket = outbound.incoming_mut().next().await.unwrap(); + // Read "hello" + out_socket.read_exact(&mut [0u8; 5]).await.unwrap(); let framed = framing::canonical(out_socket, 1024); let err = GreetingClient::connect(framed).await.unwrap_err(); assert!(matches!( @@ -370,8 +392,8 @@ async fn unknown_protocol() { #[tokio::test] async fn rejected_no_sessions_available() { - let (mut muxer, _outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 0).await; - let socket = muxer.incoming_mut().next().await.unwrap(); + let (_inbound, outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 0).await; + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, 1024); let err = GreetingClient::builder().connect(framed).await.unwrap_err(); assert!(matches!( @@ -383,8 +405,8 @@ async fn rejected_no_sessions_available() { #[tokio::test] async fn stream_still_works_after_cancel() { let service_impl = GreetingService::default(); - let (mut muxer, _outbound, _, _, _shutdown) = setup(service_impl.clone(), 1).await; - let socket = muxer.incoming_mut().next().await.unwrap(); + let (_inbound, outbound, _, _, _shutdown) = setup(service_impl.clone(), 1).await; + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, 1024); let mut client = GreetingClient::builder() @@ -423,8 +445,8 @@ async fn stream_still_works_after_cancel() { #[tokio::test] async fn stream_interruption_handling() { let service_impl = GreetingService::default(); - let (mut muxer, _outbound, _, _, _shutdown) = setup(service_impl.clone(), 1).await; - let socket = muxer.incoming_mut().next().await.unwrap(); + let (_inbound, outbound, _, _, _shutdown) = setup(service_impl.clone(), 1).await; + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, 1024); let mut client = GreetingClient::builder() @@ -471,24 +493,15 @@ async fn stream_interruption_handling() { async fn max_global_sessions() { let builder = RpcServer::builder().with_maximum_simultaneous_sessions(1); let (muxer, _outbound, context, _shutdown) = setup_service_with_builder(GreetingService::default(), builder).await; - let (_, mut inbound, outbound) = build_multiplexed_connections().await; + let (_, inbound, outbound) = build_multiplexed_connections().await; let node_identity = build_node_identity(Default::default()); // Notify that a peer wants to speak the greeting RPC protocol context.peer_manager().add_peer(node_identity.to_peer()).await.unwrap(); - for _ in 0..2 { - let substream = outbound.get_yamux_control().open_stream().await.unwrap(); - muxer - .send(ProtocolNotification::new( - ProtocolId::from_static(b"/test/greeting/1.0"), - ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream), - )) - .await - .unwrap(); - } + spawn_inbound(inbound.into_incoming(), muxer.clone(), node_identity.node_id().clone()); - let socket = inbound.incoming_mut().next().await.unwrap(); + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, 1024); let mut client = GreetingClient::builder() .with_deadline(Duration::from_secs(5)) @@ -496,7 +509,7 @@ async fn max_global_sessions() { .await .unwrap(); - let socket = inbound.incoming_mut().next().await.unwrap(); + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, 1024); let err = GreetingClient::builder() .with_deadline(Duration::from_secs(5)) @@ -508,15 +521,8 @@ async fn max_global_sessions() { unpack_enum!(RpcHandshakeError::Rejected(HandshakeRejectReason::NoSessionsAvailable) = err); client.close().await; - let substream = outbound.get_yamux_control().open_stream().await.unwrap(); - muxer - .send(ProtocolNotification::new( - ProtocolId::from_static(b"/test/greeting/1.0"), - ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream), - )) - .await - .unwrap(); - let socket = inbound.incoming_mut().next().await.unwrap(); + + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, 1024); let _client = GreetingClient::builder() .with_deadline(Duration::from_secs(5)) @@ -531,23 +537,14 @@ async fn max_per_client_sessions() { .with_maximum_simultaneous_sessions(3) .with_maximum_sessions_per_client(1); let (muxer, _outbound, context, _shutdown) = setup_service_with_builder(GreetingService::default(), builder).await; - let (_, mut inbound, outbound) = build_multiplexed_connections().await; + let (_, inbound, outbound) = build_multiplexed_connections().await; let node_identity = build_node_identity(Default::default()); // Notify that a peer wants to speak the greeting RPC protocol context.peer_manager().add_peer(node_identity.to_peer()).await.unwrap(); - for _ in 0..2 { - let substream = outbound.get_yamux_control().open_stream().await.unwrap(); - muxer - .send(ProtocolNotification::new( - ProtocolId::from_static(b"/test/greeting/1.0"), - ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream), - )) - .await - .unwrap(); - } + spawn_inbound(inbound.into_incoming(), muxer.clone(), node_identity.node_id().clone()); - let socket = inbound.incoming_mut().next().await.unwrap(); + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, 1024); let client = GreetingClient::builder() .with_deadline(Duration::from_secs(5)) @@ -555,7 +552,7 @@ async fn max_per_client_sessions() { .await .unwrap(); - let socket = inbound.incoming_mut().next().await.unwrap(); + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, 1024); let err = GreetingClient::builder() .with_deadline(Duration::from_secs(5)) @@ -567,15 +564,8 @@ async fn max_per_client_sessions() { unpack_enum!(RpcHandshakeError::Rejected(HandshakeRejectReason::NoSessionsAvailable) = err); drop(client); - let substream = outbound.get_yamux_control().open_stream().await.unwrap(); - muxer - .send(ProtocolNotification::new( - ProtocolId::from_static(b"/test/greeting/1.0"), - ProtocolEvent::NewInboundSubstream(node_identity.node_id().clone(), substream), - )) - .await - .unwrap(); - let socket = inbound.incoming_mut().next().await.unwrap(); + + let socket = outbound.get_yamux_control().open_stream().await.unwrap(); let framed = framing::canonical(socket, 1024); let _client = GreetingClient::builder() .with_deadline(Duration::from_secs(5)) diff --git a/comms/core/src/test_utils/mocks/peer_connection.rs b/comms/core/src/test_utils/mocks/peer_connection.rs index b288339ef8..dbd67f5045 100644 --- a/comms/core/src/test_utils/mocks/peer_connection.rs +++ b/comms/core/src/test_utils/mocks/peer_connection.rs @@ -33,6 +33,7 @@ use tokio::{ sync::{mpsc, Mutex}, }; use tokio_stream::StreamExt; +use yamux::ConnectionError; use crate::{ connection_manager::{ @@ -138,7 +139,7 @@ pub struct PeerConnectionMockState { impl PeerConnectionMockState { pub fn new(muxer: Yamux) -> Self { let control = muxer.get_yamux_control(); - let substream_counter = control.substream_counter(); + let substream_counter = muxer.substream_counter(); Self { call_count: Arc::new(AtomicUsize::new(0)), mux_control: Arc::new(Mutex::new(control)), @@ -172,7 +173,12 @@ impl PeerConnectionMockState { } pub async fn disconnect(&self) -> Result<(), PeerConnectionError> { - self.mux_control.lock().await.close().await.map_err(Into::into) + match self.mux_control.lock().await.close().await { + Ok(_) => Ok(()), + // Match the behaviour of the real PeerConnection. + Err(ConnectionError::Closed) => Ok(()), + Err(err) => Err(err.into()), + } } } diff --git a/comms/core/src/test_utils/transport.rs b/comms/core/src/test_utils/transport.rs index 4dd4619c49..2d7daebab2 100644 --- a/comms/core/src/test_utils/transport.rs +++ b/comms/core/src/test_utils/transport.rs @@ -39,10 +39,7 @@ pub async fn build_connected_sockets() -> (Multiaddr, MemorySocket, MemorySocket pub async fn build_multiplexed_connections() -> (Multiaddr, Yamux, Yamux) { let (addr, socket_out, socket_in) = build_connected_sockets().await; - let muxer_out = Yamux::upgrade_connection(socket_out, ConnectionDirection::Outbound).unwrap(); - let muxer_in = Yamux::upgrade_connection(socket_in, ConnectionDirection::Inbound).unwrap(); - (addr, muxer_out, muxer_in) }