Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(comms): ensure that inbound messaging terminates on disconnect #6653

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions comms/core/src/protocol/messaging/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ pub enum MessagingProtocolError {
PeerConnectionError(#[from] PeerConnectionError),
#[error("Failed to dial peer: {0}")]
PeerDialFailed(ConnectivityError),
#[error("Connectivity error: {0}")]
ConnectivityError(#[from] ConnectivityError),
#[error("IO Error: {0}")]
Io(io::Error),
#[error("Sender error: {0}")]
Expand Down
13 changes: 7 additions & 6 deletions comms/core/src/protocol/messaging/inbound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ use tokio::{
#[cfg(feature = "metrics")]
use super::metrics;
use super::{MessagingEvent, MessagingProtocol};
use crate::{message::InboundMessage, peer_manager::NodeId};
use crate::{message::InboundMessage, PeerConnection};

const LOG_TARGET: &str = "comms::protocol::messaging::inbound";

/// Inbound messaging actor. This is lazily spawned per peer when a peer requests a messaging session.
pub struct InboundMessaging {
peer: NodeId,
connection: PeerConnection,
inbound_message_tx: mpsc::Sender<InboundMessage>,
messaging_events_tx: broadcast::Sender<MessagingEvent>,
enable_message_received_event: bool,
Expand All @@ -48,14 +48,14 @@ pub struct InboundMessaging {

impl InboundMessaging {
pub fn new(
peer: NodeId,
connection: PeerConnection,
inbound_message_tx: mpsc::Sender<InboundMessage>,
messaging_events_tx: broadcast::Sender<MessagingEvent>,
enable_message_received_event: bool,
shutdown_signal: ShutdownSignal,
) -> Self {
Self {
peer,
connection,
inbound_message_tx,
messaging_events_tx,
enable_message_received_event,
Expand All @@ -65,7 +65,7 @@ impl InboundMessaging {

pub async fn run<S>(mut self, socket: S)
where S: AsyncRead + AsyncWrite + Unpin {
let peer = &self.peer;
let peer = self.connection.peer_node_id();
#[cfg(feature = "metrics")]
metrics::num_sessions().inc();
debug!(
Expand All @@ -75,13 +75,14 @@ impl InboundMessaging {
);

let stream = MessagingProtocol::framed(socket);
let stream = stream.take_until(self.connection.on_disconnect());
tokio::pin!(stream);

while let Either::Right((Some(result), _)) = future::select(self.shutdown_signal.wait(), stream.next()).await {
match result {
Ok(raw_msg) => {
#[cfg(feature = "metrics")]
metrics::inbound_message_count(&self.peer).inc();
metrics::inbound_message_count(self.connection.peer_node_id()).inc();
let msg_len = raw_msg.len();
let inbound_msg = InboundMessage::new(peer.clone(), raw_msg.freeze());
debug!(
Expand Down
26 changes: 20 additions & 6 deletions comms/core/src/protocol/messaging/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ use crate::{
ProtocolId,
ProtocolNotification,
},
PeerConnection,
};

const LOG_TARGET: &str = "comms::protocol::messaging";
Expand Down Expand Up @@ -203,7 +204,9 @@ impl MessagingProtocol {
},

Some(notification) = self.proto_notification.recv() => {
self.handle_protocol_notification(notification);
if let Err(err) = self.handle_protocol_notification(notification).await {
error!(target: LOG_TARGET, "handle_protocol_notification failed: {err}");
}
},

_ = &mut shutdown_signal => {
Expand Down Expand Up @@ -332,7 +335,8 @@ impl MessagingProtocol {
msg_tx
}

fn spawn_inbound_handler(&mut self, peer: NodeId, substream: Substream) {
fn spawn_inbound_handler(&mut self, conn: PeerConnection, substream: Substream) {
let peer = conn.peer_node_id().clone();
if let Some(handle) = self.active_inbound.get(&peer) {
if handle.is_finished() {
self.active_inbound.remove(&peer);
Expand All @@ -347,7 +351,7 @@ impl MessagingProtocol {
let messaging_events_tx = self.messaging_events_tx.clone();
let inbound_message_tx = self.inbound_message_tx.clone();
let inbound_messaging = InboundMessaging::new(
peer.clone(),
conn,
inbound_message_tx,
messaging_events_tx,
self.enable_message_received_event,
Expand All @@ -357,7 +361,10 @@ impl MessagingProtocol {
self.active_inbound.insert(peer, handle);
}

fn handle_protocol_notification(&mut self, notification: ProtocolNotification<Substream>) {
async fn handle_protocol_notification(
&mut self,
notification: ProtocolNotification<Substream>,
) -> Result<(), MessagingProtocolError> {
match notification.event {
// Peer negotiated to speak the messaging protocol with us
ProtocolEvent::NewInboundSubstream(node_id, substream) => {
Expand All @@ -366,10 +373,17 @@ impl MessagingProtocol {
"NewInboundSubstream for peer '{}'",
node_id.short_str()
);

self.spawn_inbound_handler(node_id, substream);
match self.connectivity.get_connection(node_id.clone()).await? {
Some(conn) => {
self.spawn_inbound_handler(conn, substream);
},
None => {
error!(target: LOG_TARGET, "No active connection for new inbound substream for node {node_id}");
},
}
},
}
Ok(())
}

async fn ban_peer<T: Display>(&mut self, peer_node_id: NodeId, reason: T) {
Expand Down
95 changes: 56 additions & 39 deletions comms/core/src/protocol/messaging/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ use crate::{
mocks::{create_connectivity_mock, create_peer_connection_mock_pair, ConnectivityManagerMockState},
node_id,
node_identity::build_node_identity,
transport,
},
types::{CommsDatabase, CommsPublicKey},
};
Expand Down Expand Up @@ -108,34 +107,47 @@ async fn spawn_messaging_protocol() -> (

#[tokio::test]
async fn new_inbound_substream_handling() {
let (peer_manager, _, _, proto_tx, _, mut inbound_msg_rx, mut events_rx, _shutdown) =
let (peer_manager, _, conn_man_mock, proto_tx, outbound_msg_tx, mut inbound_msg_rx, mut events_rx, _shutdown) =
spawn_messaging_protocol().await;

let expected_node_id = node_id::random();
let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng);
peer_manager
.add_peer(Peer::new(
pk.clone(),
expected_node_id.clone(),
MultiaddressesWithStats::default(),
PeerFlags::empty(),
PeerFeatures::COMMUNICATION_CLIENT,
Default::default(),
Default::default(),
))
.await
.unwrap();
let peer1 = Peer::new(
pk.clone(),
expected_node_id.clone(),
MultiaddressesWithStats::default(),
PeerFlags::empty(),
PeerFeatures::COMMUNICATION_CLIENT,
Default::default(),
Default::default(),
);
peer_manager.add_peer(peer1.clone()).await.unwrap();

// Create connected memory sockets - we use each end of the connection as if they exist on different nodes
let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await;
let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng);
let peer2 = Peer::new(
pk.clone(),
expected_node_id.clone(),
MultiaddressesWithStats::default(),
PeerFlags::empty(),
PeerFeatures::COMMUNICATION_CLIENT,
Default::default(),
Default::default(),
);

let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap();
let (_, conn1_state, conn2, _conn2_state) = create_peer_connection_mock_pair(peer1.clone(), peer2.clone()).await;

let mut framed_ours = MessagingProtocol::framed(stream_ours);
framed_ours.send(TEST_MSG1.clone()).await.unwrap();
conn_man_mock.add_active_connection(conn2).await;

let (reply_tx, _reply_rx) = oneshot::channel();
let out_msg = OutboundMessage {
tag: MessageTag::new(),
reply: reply_tx.into(),
peer_node_id: peer1.node_id.clone(),
body: TEST_MSG1.clone(),
};
outbound_msg_tx.send(out_msg).unwrap();

// Notify the messaging protocol that a new substream has been established that wants to talk the messaging.
let stream_theirs = muxer_theirs.incoming_mut().next().await.unwrap();
let stream_theirs = conn1_state.next_incoming_substream().await.unwrap();
proto_tx
.send(ProtocolNotification::new(
MESSAGING_PROTOCOL_ID.clone(),
Expand Down Expand Up @@ -352,30 +364,35 @@ async fn many_concurrent_send_message_requests_that_fail() {

#[tokio::test]
async fn new_inbound_substream_only_single_session_permitted() {
let (peer_manager, _, _, proto_tx, _, mut inbound_msg_rx, _, _shutdown) = spawn_messaging_protocol().await;
let (peer_manager, node_identity_1, conn_man_mock, proto_tx, _, mut inbound_msg_rx, _, _shutdown) =
spawn_messaging_protocol().await;

let expected_node_id = node_id::random();
let peer1 = node_identity_1.to_peer();

let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng);
peer_manager
.add_peer(Peer::new(
pk.clone(),
expected_node_id.clone(),
MultiaddressesWithStats::default(),
PeerFlags::empty(),
PeerFeatures::COMMUNICATION_CLIENT,
Default::default(),
Default::default(),
))
.await
.unwrap();
let peer2 = Peer::new(
pk.clone(),
expected_node_id.clone(),
MultiaddressesWithStats::default(),
PeerFlags::empty(),
PeerFeatures::COMMUNICATION_CLIENT,
Default::default(),
Default::default(),
);
peer_manager.add_peer(peer2.clone()).await.unwrap();

let (conn1, conn1_state, _, conn2_state) = create_peer_connection_mock_pair(peer1.clone(), peer2.clone()).await;

conn_man_mock.add_active_connection(conn1).await;

// Create connected memory sockets - we use each end of the connection as if they exist on different nodes
let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await;
// let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await;
// Spawn a task to deal with incoming substreams
tokio::spawn({
let expected_node_id = expected_node_id.clone();
async move {
while let Some(stream_theirs) = muxer_theirs.incoming_mut().next().await {
while let Some(stream_theirs) = conn2_state.next_incoming_substream().await {
proto_tx
.send(ProtocolNotification::new(
MESSAGING_PROTOCOL_ID.clone(),
Expand All @@ -388,7 +405,7 @@ async fn new_inbound_substream_only_single_session_permitted() {
});

// Open first stream
let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap();
let stream_ours = conn1_state.open_substream().await.unwrap();
let mut framed_ours = MessagingProtocol::framed(stream_ours);
framed_ours.send(TEST_MSG1.clone()).await.unwrap();

Expand All @@ -401,7 +418,7 @@ async fn new_inbound_substream_only_single_session_permitted() {
assert_eq!(in_msg.body, TEST_MSG1);

// Check the second stream closes immediately
let stream_ours2 = muxer_ours.get_yamux_control().open_stream().await.unwrap();
let stream_ours2 = conn1_state.open_substream().await.unwrap();

let mut framed_ours2 = MessagingProtocol::framed(stream_ours2);
// Check that it eventually exits. The first send will initiate the substream and send. Once the other side closes
Expand Down Expand Up @@ -431,7 +448,7 @@ async fn new_inbound_substream_only_single_session_permitted() {
framed_ours.close().await.unwrap();

// Open another one for messaging
let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap();
let stream_ours = conn1_state.open_substream().await.unwrap();
let mut framed_ours = MessagingProtocol::framed(stream_ours);
framed_ours.send(TEST_MSG1.clone()).await.unwrap();

Expand Down
Loading