diff --git a/Cargo.lock b/Cargo.lock index 73e0804451..9e7cc3d47b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7890,6 +7890,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "0.26.6" @@ -8215,18 +8225,18 @@ dependencies = [ [[package]] name = "yamux" -version = "0.13.2" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f97202f6b125031b95d83e01dc57292b529384f80bfae4677e4bbc10178cf72" +checksum = "a31b5e376a8b012bee9c423acdbb835fc34d45001cfa3106236a624e4b738028" dependencies = [ "futures 0.3.29", - "instant", "log", "nohash-hasher", "parking_lot 0.12.1", "pin-project 1.1.3", "rand", "static_assertions", + "web-time", ] [[package]] diff --git a/comms/core/src/protocol/messaging/error.rs b/comms/core/src/protocol/messaging/error.rs index eaf02e1672..ab2a4c530b 100644 --- a/comms/core/src/protocol/messaging/error.rs +++ b/comms/core/src/protocol/messaging/error.rs @@ -53,6 +53,8 @@ pub enum MessagingProtocolError { PeerConnectionError(#[from] PeerConnectionError), #[error("Failed to dial peer: {0}")] PeerDialFailed(ConnectivityError), + #[error("Connectivity error: {0}")] + ConnectivityError(#[from] ConnectivityError), #[error("IO Error: {0}")] Io(io::Error), #[error("Sender error: {0}")] diff --git a/comms/core/src/protocol/messaging/inbound.rs b/comms/core/src/protocol/messaging/inbound.rs index f2c393af86..97b3f545ff 100644 --- a/comms/core/src/protocol/messaging/inbound.rs +++ b/comms/core/src/protocol/messaging/inbound.rs @@ -33,13 +33,13 @@ use tokio::{ #[cfg(feature = "metrics")] use super::metrics; use super::{MessagingEvent, MessagingProtocol}; -use crate::{message::InboundMessage, peer_manager::NodeId}; +use crate::{message::InboundMessage, PeerConnection}; const LOG_TARGET: &str = "comms::protocol::messaging::inbound"; /// Inbound messaging actor. This is lazily spawned per peer when a peer requests a messaging session. pub struct InboundMessaging { - peer: NodeId, + connection: PeerConnection, inbound_message_tx: mpsc::Sender, messaging_events_tx: broadcast::Sender, enable_message_received_event: bool, @@ -48,14 +48,14 @@ pub struct InboundMessaging { impl InboundMessaging { pub fn new( - peer: NodeId, + connection: PeerConnection, inbound_message_tx: mpsc::Sender, messaging_events_tx: broadcast::Sender, enable_message_received_event: bool, shutdown_signal: ShutdownSignal, ) -> Self { Self { - peer, + connection, inbound_message_tx, messaging_events_tx, enable_message_received_event, @@ -65,7 +65,7 @@ impl InboundMessaging { pub async fn run(mut self, socket: S) where S: AsyncRead + AsyncWrite + Unpin { - let peer = &self.peer; + let peer = self.connection.peer_node_id(); #[cfg(feature = "metrics")] metrics::num_sessions().inc(); debug!( @@ -75,13 +75,14 @@ impl InboundMessaging { ); let stream = MessagingProtocol::framed(socket); + let stream = stream.take_until(self.connection.on_disconnect()); tokio::pin!(stream); while let Either::Right((Some(result), _)) = future::select(self.shutdown_signal.wait(), stream.next()).await { match result { Ok(raw_msg) => { #[cfg(feature = "metrics")] - metrics::inbound_message_count(&self.peer).inc(); + metrics::inbound_message_count(self.connection.peer_node_id()).inc(); let msg_len = raw_msg.len(); let inbound_msg = InboundMessage::new(peer.clone(), raw_msg.freeze()); debug!( diff --git a/comms/core/src/protocol/messaging/protocol.rs b/comms/core/src/protocol/messaging/protocol.rs index fd5434cd3a..c660dae70d 100644 --- a/comms/core/src/protocol/messaging/protocol.rs +++ b/comms/core/src/protocol/messaging/protocol.rs @@ -50,6 +50,7 @@ use crate::{ ProtocolId, ProtocolNotification, }, + PeerConnection, }; const LOG_TARGET: &str = "comms::protocol::messaging"; @@ -203,7 +204,9 @@ impl MessagingProtocol { }, Some(notification) = self.proto_notification.recv() => { - self.handle_protocol_notification(notification); + if let Err(err) = self.handle_protocol_notification(notification).await { + error!(target: LOG_TARGET, "handle_protocol_notification failed: {err}"); + } }, _ = &mut shutdown_signal => { @@ -332,7 +335,8 @@ impl MessagingProtocol { msg_tx } - fn spawn_inbound_handler(&mut self, peer: NodeId, substream: Substream) { + fn spawn_inbound_handler(&mut self, conn: PeerConnection, substream: Substream) { + let peer = conn.peer_node_id().clone(); if let Some(handle) = self.active_inbound.get(&peer) { if handle.is_finished() { self.active_inbound.remove(&peer); @@ -347,7 +351,7 @@ impl MessagingProtocol { let messaging_events_tx = self.messaging_events_tx.clone(); let inbound_message_tx = self.inbound_message_tx.clone(); let inbound_messaging = InboundMessaging::new( - peer.clone(), + conn, inbound_message_tx, messaging_events_tx, self.enable_message_received_event, @@ -357,7 +361,10 @@ impl MessagingProtocol { self.active_inbound.insert(peer, handle); } - fn handle_protocol_notification(&mut self, notification: ProtocolNotification) { + async fn handle_protocol_notification( + &mut self, + notification: ProtocolNotification, + ) -> Result<(), MessagingProtocolError> { match notification.event { // Peer negotiated to speak the messaging protocol with us ProtocolEvent::NewInboundSubstream(node_id, substream) => { @@ -366,10 +373,17 @@ impl MessagingProtocol { "NewInboundSubstream for peer '{}'", node_id.short_str() ); - - self.spawn_inbound_handler(node_id, substream); + match self.connectivity.get_connection(node_id.clone()).await? { + Some(conn) => { + self.spawn_inbound_handler(conn, substream); + }, + None => { + error!(target: LOG_TARGET, "No active connection for new inbound substream for node {node_id}"); + }, + } }, } + Ok(()) } async fn ban_peer(&mut self, peer_node_id: NodeId, reason: T) { diff --git a/comms/core/src/protocol/messaging/test.rs b/comms/core/src/protocol/messaging/test.rs index 2a27e9e7ca..08d6b901c7 100644 --- a/comms/core/src/protocol/messaging/test.rs +++ b/comms/core/src/protocol/messaging/test.rs @@ -49,7 +49,6 @@ use crate::{ mocks::{create_connectivity_mock, create_peer_connection_mock_pair, ConnectivityManagerMockState}, node_id, node_identity::build_node_identity, - transport, }, types::{CommsDatabase, CommsPublicKey}, }; @@ -108,34 +107,47 @@ async fn spawn_messaging_protocol() -> ( #[tokio::test] async fn new_inbound_substream_handling() { - let (peer_manager, _, _, proto_tx, _, mut inbound_msg_rx, mut events_rx, _shutdown) = + let (peer_manager, _, conn_man_mock, proto_tx, outbound_msg_tx, mut inbound_msg_rx, mut events_rx, _shutdown) = spawn_messaging_protocol().await; let expected_node_id = node_id::random(); let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng); - peer_manager - .add_peer(Peer::new( - pk.clone(), - expected_node_id.clone(), - MultiaddressesWithStats::default(), - PeerFlags::empty(), - PeerFeatures::COMMUNICATION_CLIENT, - Default::default(), - Default::default(), - )) - .await - .unwrap(); + let peer1 = Peer::new( + pk.clone(), + expected_node_id.clone(), + MultiaddressesWithStats::default(), + PeerFlags::empty(), + PeerFeatures::COMMUNICATION_CLIENT, + Default::default(), + Default::default(), + ); + peer_manager.add_peer(peer1.clone()).await.unwrap(); - // Create connected memory sockets - we use each end of the connection as if they exist on different nodes - let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await; + let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng); + let peer2 = Peer::new( + pk.clone(), + expected_node_id.clone(), + MultiaddressesWithStats::default(), + PeerFlags::empty(), + PeerFeatures::COMMUNICATION_CLIENT, + Default::default(), + Default::default(), + ); - let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap(); + let (_, conn1_state, conn2, _conn2_state) = create_peer_connection_mock_pair(peer1.clone(), peer2.clone()).await; - let mut framed_ours = MessagingProtocol::framed(stream_ours); - framed_ours.send(TEST_MSG1.clone()).await.unwrap(); + conn_man_mock.add_active_connection(conn2).await; + + let (reply_tx, _reply_rx) = oneshot::channel(); + let out_msg = OutboundMessage { + tag: MessageTag::new(), + reply: reply_tx.into(), + peer_node_id: peer1.node_id.clone(), + body: TEST_MSG1.clone(), + }; + outbound_msg_tx.send(out_msg).unwrap(); - // Notify the messaging protocol that a new substream has been established that wants to talk the messaging. - let stream_theirs = muxer_theirs.incoming_mut().next().await.unwrap(); + let stream_theirs = conn1_state.next_incoming_substream().await.unwrap(); proto_tx .send(ProtocolNotification::new( MESSAGING_PROTOCOL_ID.clone(), @@ -352,30 +364,35 @@ async fn many_concurrent_send_message_requests_that_fail() { #[tokio::test] async fn new_inbound_substream_only_single_session_permitted() { - let (peer_manager, _, _, proto_tx, _, mut inbound_msg_rx, _, _shutdown) = spawn_messaging_protocol().await; + let (peer_manager, node_identity_1, conn_man_mock, proto_tx, _, mut inbound_msg_rx, _, _shutdown) = + spawn_messaging_protocol().await; let expected_node_id = node_id::random(); + let peer1 = node_identity_1.to_peer(); + let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng); - peer_manager - .add_peer(Peer::new( - pk.clone(), - expected_node_id.clone(), - MultiaddressesWithStats::default(), - PeerFlags::empty(), - PeerFeatures::COMMUNICATION_CLIENT, - Default::default(), - Default::default(), - )) - .await - .unwrap(); + let peer2 = Peer::new( + pk.clone(), + expected_node_id.clone(), + MultiaddressesWithStats::default(), + PeerFlags::empty(), + PeerFeatures::COMMUNICATION_CLIENT, + Default::default(), + Default::default(), + ); + peer_manager.add_peer(peer2.clone()).await.unwrap(); + + let (conn1, conn1_state, _, conn2_state) = create_peer_connection_mock_pair(peer1.clone(), peer2.clone()).await; + + conn_man_mock.add_active_connection(conn1).await; // Create connected memory sockets - we use each end of the connection as if they exist on different nodes - let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await; + // let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await; // Spawn a task to deal with incoming substreams tokio::spawn({ let expected_node_id = expected_node_id.clone(); async move { - while let Some(stream_theirs) = muxer_theirs.incoming_mut().next().await { + while let Some(stream_theirs) = conn2_state.next_incoming_substream().await { proto_tx .send(ProtocolNotification::new( MESSAGING_PROTOCOL_ID.clone(), @@ -388,7 +405,7 @@ async fn new_inbound_substream_only_single_session_permitted() { }); // Open first stream - let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap(); + let stream_ours = conn1_state.open_substream().await.unwrap(); let mut framed_ours = MessagingProtocol::framed(stream_ours); framed_ours.send(TEST_MSG1.clone()).await.unwrap(); @@ -401,7 +418,7 @@ async fn new_inbound_substream_only_single_session_permitted() { assert_eq!(in_msg.body, TEST_MSG1); // Check the second stream closes immediately - let stream_ours2 = muxer_ours.get_yamux_control().open_stream().await.unwrap(); + let stream_ours2 = conn1_state.open_substream().await.unwrap(); let mut framed_ours2 = MessagingProtocol::framed(stream_ours2); // Check that it eventually exits. The first send will initiate the substream and send. Once the other side closes @@ -431,7 +448,7 @@ async fn new_inbound_substream_only_single_session_permitted() { framed_ours.close().await.unwrap(); // Open another one for messaging - let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap(); + let stream_ours = conn1_state.open_substream().await.unwrap(); let mut framed_ours = MessagingProtocol::framed(stream_ours); framed_ours.send(TEST_MSG1.clone()).await.unwrap();