fix(comms): ensure that inbound messaging terminates on disconnect

sdbondi committed Oct 29, 2024
1 parent f802743 commit 6629a26

Showing 5 changed files with 98 additions and 54 deletions.
16 changes: 13 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions comms/core/src/protocol/messaging/error.rs
@@ -53,6 +53,8 @@ pub enum MessagingProtocolError {
     PeerConnectionError(#[from] PeerConnectionError),
     #[error("Failed to dial peer: {0}")]
     PeerDialFailed(ConnectivityError),
+    #[error("Connectivity error: {0}")]
+    ConnectivityError(#[from] ConnectivityError),
     #[error("IO Error: {0}")]
     Io(io::Error),
     #[error("Sender error: {0}")]
13 changes: 7 additions & 6 deletions comms/core/src/protocol/messaging/inbound.rs
@@ -33,13 +33,13 @@ use tokio::{
 #[cfg(feature = "metrics")]
 use super::metrics;
 use super::{MessagingEvent, MessagingProtocol};
-use crate::{message::InboundMessage, peer_manager::NodeId};
+use crate::{message::InboundMessage, PeerConnection};
 
 const LOG_TARGET: &str = "comms::protocol::messaging::inbound";
 
 /// Inbound messaging actor. This is lazily spawned per peer when a peer requests a messaging session.
 pub struct InboundMessaging {
-    peer: NodeId,
+    connection: PeerConnection,
     inbound_message_tx: mpsc::Sender<InboundMessage>,
     messaging_events_tx: broadcast::Sender<MessagingEvent>,
     enable_message_received_event: bool,
@@ -48,14 +48,14 @@ pub struct InboundMessaging {
 
 impl InboundMessaging {
     pub fn new(
-        peer: NodeId,
+        connection: PeerConnection,
         inbound_message_tx: mpsc::Sender<InboundMessage>,
         messaging_events_tx: broadcast::Sender<MessagingEvent>,
         enable_message_received_event: bool,
         shutdown_signal: ShutdownSignal,
     ) -> Self {
         Self {
-            peer,
+            connection,
             inbound_message_tx,
             messaging_events_tx,
             enable_message_received_event,
@@ -65,7 +65,7 @@ impl InboundMessaging {
 
     pub async fn run<S>(mut self, socket: S)
     where S: AsyncRead + AsyncWrite + Unpin {
-        let peer = &self.peer;
+        let peer = self.connection.peer_node_id();
         #[cfg(feature = "metrics")]
         metrics::num_sessions().inc();
         debug!(
@@ -75,13 +75,14 @@ impl InboundMessaging {
         );
 
         let stream = MessagingProtocol::framed(socket);
+        let stream = stream.take_until(self.connection.on_disconnect());
         tokio::pin!(stream);
 
         while let Either::Right((Some(result), _)) = future::select(self.shutdown_signal.wait(), stream.next()).await {
             match result {
                 Ok(raw_msg) => {
                     #[cfg(feature = "metrics")]
-                    metrics::inbound_message_count(&self.peer).inc();
+                    metrics::inbound_message_count(self.connection.peer_node_id()).inc();
                     let msg_len = raw_msg.len();
                     let inbound_msg = InboundMessage::new(peer.clone(), raw_msg.freeze());
                     debug!(
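The hunk above is the heart of the fix: InboundMessaging now holds the PeerConnection and wraps the framed socket with take_until(self.connection.on_disconnect()), so the read loop ends as soon as the peer disconnects instead of waiting indefinitely on a dead substream. Below is a standalone sketch of that take_until pattern, not code from this repository; it assumes the futures and tokio crates and uses a oneshot channel as a stand-in for PeerConnection::on_disconnect().

use std::time::Duration;

use futures::{stream, StreamExt};
use tokio::sync::oneshot;

#[tokio::main]
async fn main() {
    // Stand-in for the framed inbound socket: yields a message every couple of milliseconds, forever.
    let inbound = stream::unfold(0u32, |i| async move {
        tokio::time::sleep(Duration::from_millis(2)).await;
        Some((format!("msg {i}"), i + 1))
    });

    // Stand-in for PeerConnection::on_disconnect(): a future that resolves when the peer drops.
    let (disconnect_tx, on_disconnect) = oneshot::channel::<()>();

    // Same shape as the change in inbound.rs: the wrapped stream ends as soon as the
    // disconnect future resolves, even though the underlying stream never finishes.
    let inbound = inbound.take_until(on_disconnect);
    tokio::pin!(inbound);

    // Simulate the peer disconnecting shortly after the session starts.
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(10)).await;
        let _ = disconnect_tx.send(());
    });

    while let Some(msg) = inbound.next().await {
        println!("received {msg}");
    }
    println!("inbound loop terminated because the peer disconnected");
}

Even though the message stream itself is endless, the loop exits once the disconnect future resolves, which is exactly the termination guarantee this commit adds to the inbound messaging session.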
26 changes: 20 additions & 6 deletions comms/core/src/protocol/messaging/protocol.rs
@@ -50,6 +50,7 @@ use crate::{
         ProtocolId,
         ProtocolNotification,
     },
+    PeerConnection,
 };
 
 const LOG_TARGET: &str = "comms::protocol::messaging";
@@ -203,7 +204,9 @@ impl MessagingProtocol {
             },
 
             Some(notification) = self.proto_notification.recv() => {
-                self.handle_protocol_notification(notification);
+                if let Err(err) = self.handle_protocol_notification(notification).await {
+                    error!(target: LOG_TARGET, "handle_protocol_notification failed: {err}");
+                }
             },
 
             _ = &mut shutdown_signal => {
@@ -332,7 +335,8 @@ impl MessagingProtocol {
         msg_tx
     }
 
-    fn spawn_inbound_handler(&mut self, peer: NodeId, substream: Substream) {
+    fn spawn_inbound_handler(&mut self, conn: PeerConnection, substream: Substream) {
+        let peer = conn.peer_node_id().clone();
         if let Some(handle) = self.active_inbound.get(&peer) {
             if handle.is_finished() {
                 self.active_inbound.remove(&peer);
@@ -347,7 +351,7 @@ impl MessagingProtocol {
         let messaging_events_tx = self.messaging_events_tx.clone();
         let inbound_message_tx = self.inbound_message_tx.clone();
         let inbound_messaging = InboundMessaging::new(
-            peer.clone(),
+            conn,
             inbound_message_tx,
             messaging_events_tx,
             self.enable_message_received_event,
@@ -357,7 +361,10 @@ impl MessagingProtocol {
         self.active_inbound.insert(peer, handle);
     }
 
-    fn handle_protocol_notification(&mut self, notification: ProtocolNotification<Substream>) {
+    async fn handle_protocol_notification(
+        &mut self,
+        notification: ProtocolNotification<Substream>,
+    ) -> Result<(), MessagingProtocolError> {
         match notification.event {
             // Peer negotiated to speak the messaging protocol with us
             ProtocolEvent::NewInboundSubstream(node_id, substream) => {
@@ -366,10 +373,17 @@ impl MessagingProtocol {
                     "NewInboundSubstream for peer '{}'",
                     node_id.short_str()
                 );
-
-                self.spawn_inbound_handler(node_id, substream);
+                match self.connectivity.get_connection(node_id.clone()).await? {
+                    Some(conn) => {
+                        self.spawn_inbound_handler(conn, substream);
+                    },
+                    None => {
+                        error!(target: LOG_TARGET, "No active connection for new inbound substream for node {node_id}");
+                    },
+                }
             },
         }
+        Ok(())
     }
 
     async fn ban_peer<T: Display>(&mut self, peer_node_id: NodeId, reason: T) {
95 changes: 56 additions & 39 deletions comms/core/src/protocol/messaging/test.rs
@@ -49,7 +49,6 @@ use crate::{
         mocks::{create_connectivity_mock, create_peer_connection_mock_pair, ConnectivityManagerMockState},
         node_id,
         node_identity::build_node_identity,
-        transport,
     },
     types::{CommsDatabase, CommsPublicKey},
 };
@@ -108,34 +107,47 @@ async fn spawn_messaging_protocol() -> (
 
 #[tokio::test]
 async fn new_inbound_substream_handling() {
-    let (peer_manager, _, _, proto_tx, _, mut inbound_msg_rx, mut events_rx, _shutdown) =
+    let (peer_manager, _, conn_man_mock, proto_tx, outbound_msg_tx, mut inbound_msg_rx, mut events_rx, _shutdown) =
         spawn_messaging_protocol().await;
 
     let expected_node_id = node_id::random();
     let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng);
-    peer_manager
-        .add_peer(Peer::new(
-            pk.clone(),
-            expected_node_id.clone(),
-            MultiaddressesWithStats::default(),
-            PeerFlags::empty(),
-            PeerFeatures::COMMUNICATION_CLIENT,
-            Default::default(),
-            Default::default(),
-        ))
-        .await
-        .unwrap();
+    let peer1 = Peer::new(
+        pk.clone(),
+        expected_node_id.clone(),
+        MultiaddressesWithStats::default(),
+        PeerFlags::empty(),
+        PeerFeatures::COMMUNICATION_CLIENT,
+        Default::default(),
+        Default::default(),
+    );
+    peer_manager.add_peer(peer1.clone()).await.unwrap();
 
-    // Create connected memory sockets - we use each end of the connection as if they exist on different nodes
-    let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await;
+    let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng);
+    let peer2 = Peer::new(
+        pk.clone(),
+        expected_node_id.clone(),
+        MultiaddressesWithStats::default(),
+        PeerFlags::empty(),
+        PeerFeatures::COMMUNICATION_CLIENT,
+        Default::default(),
+        Default::default(),
+    );
 
-    let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap();
+    let (_, conn1_state, conn2, _conn2_state) = create_peer_connection_mock_pair(peer1.clone(), peer2.clone()).await;
 
-    let mut framed_ours = MessagingProtocol::framed(stream_ours);
-    framed_ours.send(TEST_MSG1.clone()).await.unwrap();
+    conn_man_mock.add_active_connection(conn2).await;
+
+    let (reply_tx, _reply_rx) = oneshot::channel();
+    let out_msg = OutboundMessage {
+        tag: MessageTag::new(),
+        reply: reply_tx.into(),
+        peer_node_id: peer1.node_id.clone(),
+        body: TEST_MSG1.clone(),
+    };
+    outbound_msg_tx.send(out_msg).unwrap();
 
     // Notify the messaging protocol that a new substream has been established that wants to talk the messaging.
-    let stream_theirs = muxer_theirs.incoming_mut().next().await.unwrap();
+    let stream_theirs = conn1_state.next_incoming_substream().await.unwrap();
     proto_tx
         .send(ProtocolNotification::new(
            MESSAGING_PROTOCOL_ID.clone(),
@@ -352,30 +364,35 @@ async fn many_concurrent_send_message_requests_that_fail() {
 
 #[tokio::test]
 async fn new_inbound_substream_only_single_session_permitted() {
-    let (peer_manager, _, _, proto_tx, _, mut inbound_msg_rx, _, _shutdown) = spawn_messaging_protocol().await;
+    let (peer_manager, node_identity_1, conn_man_mock, proto_tx, _, mut inbound_msg_rx, _, _shutdown) =
+        spawn_messaging_protocol().await;
 
     let expected_node_id = node_id::random();
+    let peer1 = node_identity_1.to_peer();
+
     let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng);
-    peer_manager
-        .add_peer(Peer::new(
-            pk.clone(),
-            expected_node_id.clone(),
-            MultiaddressesWithStats::default(),
-            PeerFlags::empty(),
-            PeerFeatures::COMMUNICATION_CLIENT,
-            Default::default(),
-            Default::default(),
-        ))
-        .await
-        .unwrap();
+    let peer2 = Peer::new(
+        pk.clone(),
+        expected_node_id.clone(),
+        MultiaddressesWithStats::default(),
+        PeerFlags::empty(),
+        PeerFeatures::COMMUNICATION_CLIENT,
+        Default::default(),
+        Default::default(),
+    );
+    peer_manager.add_peer(peer2.clone()).await.unwrap();
+
+    let (conn1, conn1_state, _, conn2_state) = create_peer_connection_mock_pair(peer1.clone(), peer2.clone()).await;
+
+    conn_man_mock.add_active_connection(conn1).await;
 
     // Create connected memory sockets - we use each end of the connection as if they exist on different nodes
-    let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await;
+    // let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await;
     // Spawn a task to deal with incoming substreams
     tokio::spawn({
        let expected_node_id = expected_node_id.clone();
        async move {
-            while let Some(stream_theirs) = muxer_theirs.incoming_mut().next().await {
+            while let Some(stream_theirs) = conn2_state.next_incoming_substream().await {
                proto_tx
                    .send(ProtocolNotification::new(
                        MESSAGING_PROTOCOL_ID.clone(),
@@ -388,7 +405,7 @@ async fn new_inbound_substream_only_single_session_permitted() {
    });
 
    // Open first stream
-    let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap();
+    let stream_ours = conn1_state.open_substream().await.unwrap();
    let mut framed_ours = MessagingProtocol::framed(stream_ours);
    framed_ours.send(TEST_MSG1.clone()).await.unwrap();
 
@@ -401,7 +418,7 @@ async fn new_inbound_substream_only_single_session_permitted() {
    assert_eq!(in_msg.body, TEST_MSG1);
 
    // Check the second stream closes immediately
-    let stream_ours2 = muxer_ours.get_yamux_control().open_stream().await.unwrap();
+    let stream_ours2 = conn1_state.open_substream().await.unwrap();
 
    let mut framed_ours2 = MessagingProtocol::framed(stream_ours2);
    // Check that it eventually exits. The first send will initiate the substream and send. Once the other side closes
@@ -431,7 +448,7 @@ async fn new_inbound_substream_only_single_session_permitted() {
    framed_ours.close().await.unwrap();
 
    // Open another one for messaging
-    let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap();
+    let stream_ours = conn1_state.open_substream().await.unwrap();
    let mut framed_ours = MessagingProtocol::framed(stream_ours);
    framed_ours.send(TEST_MSG1.clone()).await.unwrap();
 
