From 8449785864d90f17216b2ade9a499905e578a816 Mon Sep 17 00:00:00 2001 From: Flawed <33593723+ff14wed@users.noreply.github.com> Date: Sun, 12 Mar 2023 07:17:25 -0700 Subject: [PATCH] Ensure proper cleanup of subscribers In the case of server shutdown, this change ensures that we abort the tasks responsible for handling client connections before fully shutting down. This allows the named pipe connections to be dropped properly. (Fixes #13) In the case of a client disconnection at any point in the message loop for the subscriber, the server will ensure that the disconnect is handled instead of exiting the subscriber handler immediately. This fixes a rare issue wherein the server does not exit when the last client disconnects because the disconnection happens while the server is sending a message to the subscriber. --- src/server.rs | 562 +++++++++++++++++++++++++------------------------- 1 file changed, 285 insertions(+), 277 deletions(-) diff --git a/src/server.rs b/src/server.rs index 6bb0287..7936da4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -9,6 +9,7 @@ use futures::{SinkExt, Stream, StreamExt}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::{mpsc, Mutex}; +use tokio::task::JoinSet; use tokio_util::codec::Framed; use once_cell::sync::OnceCell; @@ -200,8 +201,8 @@ impl Server { } } - /// Handle the message from subscriber and send a success/failure response back - async fn handle_subscriber_message( + /// Handle the payload from subscriber and send a success/failure response back + async fn handle_payload_from_subscriber( payload: rpc::Payload, subscriber: &mut Subscriber, payload_handler: &F, @@ -245,37 +246,19 @@ impl Server { Ok(()) } - /// Handle an individual subscriber - async fn handle_subscriber( + /// Subscriber message loop. Returns an error if it has trouble writing + /// to the connection. + /// Returns true if it returned because of an Exit request. Returns false + /// if the connection was closed naturally. 
+ async fn subscriber_msg_loop( &self, - stream: impl AsyncRead + AsyncWrite + std::marker::Unpin, + mut subscriber: &mut Subscriber, payload_handler: F, - ) -> Result<()> + ) -> Result where + T: AsyncRead + AsyncWrite + std::marker::Unpin, F: Fn(rpc::Payload) -> Result<()>, { - let codec = rpc::PayloadCodec::new(); - let frames = Framed::new(stream, codec); - - let (id, rx) = self.state.lock().await.new_subscriber(); - let hello_string = self.state.lock().await.server_hello_string(); - - let mut subscriber = Subscriber { id, frames, rx }; - - subscriber - .frames - .send( - rpc::Payload { - op: rpc::MessageOps::Debug, - ctx: 9000, - data: hello_string.into_bytes(), - } - .into(), - ) - .await?; - - info!("New subscriber connected: {}", subscriber.id); - let ping_payload = rpc::Payload { op: rpc::MessageOps::Ping, ctx: 0, @@ -288,7 +271,7 @@ impl Server { // Process incoming messages until our stream is exhausted by a disconnect. while let Some(result) = subscriber.next().await { match result { - // A request was received from the current user + // A request was received from the current subscriber Ok(Message::Request(payload)) => match payload.op { rpc::MessageOps::Ping => { subscriber.frames.send(ping_payload.clone()).await?; @@ -296,7 +279,7 @@ impl Server { rpc::MessageOps::Exit => { info!("Shutting down server because Exit payload received"); self.shutdown().await; - return Ok(()); + return Ok(true); } rpc::MessageOps::Option => { filter = payload.ctx; @@ -310,8 +293,12 @@ impl Server { .await? } _ => { - Self::handle_subscriber_message(payload, &mut subscriber, &payload_handler) - .await?; + Self::handle_payload_from_subscriber( + payload, + &mut subscriber, + &payload_handler, + ) + .await?; } }, @@ -330,7 +317,57 @@ impl Server { } } - // If this section is reached it means that the subscriber was disconnected! 
+ Ok(false) + } + + /// Handle an individual subscriber + async fn handle_subscriber( + &self, + stream: impl AsyncRead + AsyncWrite + std::marker::Unpin, + payload_handler: F, + ) -> Result<()> + where + F: Fn(rpc::Payload) -> Result<()>, + { + let codec = rpc::PayloadCodec::new(); + let frames = Framed::new(stream, codec); + + let (id, rx) = self.state.lock().await.new_subscriber(); + let hello_string = self.state.lock().await.server_hello_string(); + + let mut subscriber = Subscriber { id, frames, rx }; + + subscriber + .frames + .send( + rpc::Payload { + op: rpc::MessageOps::Debug, + ctx: 9000, + data: hello_string.into_bytes(), + } + .into(), + ) + .await?; + + info!("New subscriber connected: {}", subscriber.id); + + match self + .subscriber_msg_loop(&mut subscriber, payload_handler) + .await + { + Ok(server_exit) => { + if server_exit { + return Ok(()); + } + } + Err(e) => error!( + "Disconecting subscriber {} because of error: {}", + subscriber.id, e + ), + } + + // If this section is reached it means that the subscriber was + // disconnected one way or another. 
{ info!("Subscriber disconnected: {}", subscriber.id); let mut state = self.state.lock().await; @@ -368,6 +405,8 @@ impl Server { trigger.cancel(); }); + let mut subscriber_set = JoinSet::new(); + // Wait on subscribers and create a new loop task for each new // connection while let Some(result) = incoming.next().await { @@ -375,7 +414,7 @@ impl Server { Ok(stream) => { let handler = payload_handler.clone(); let self_clone = self.clone(); - tokio::spawn(async move { + subscriber_set.spawn(async move { if let Err(e) = self_clone.handle_subscriber(stream, handler).await { error!("Error occurred when processing stream = {:?}", e); } @@ -384,6 +423,9 @@ impl Server { Err(e) => error!("Unable to connect to subscriber: {}", e), } } + info!("Shutting down subscriber handlers"); + subscriber_set.shutdown().await; + info!("Server shut down!"); Ok(()) } @@ -391,10 +433,12 @@ impl Server { #[cfg(test)] mod tests { + use std::sync::atomic::{AtomicU32, Ordering}; + use super::*; use ntest::timeout; use rand::Rng; - use tokio::select; + use tokio::{select, task::JoinHandle}; #[test] fn test_individual_packet_filters() { @@ -456,33 +500,32 @@ mod tests { } } - #[tokio::test(flavor = "multi_thread")] - #[timeout(10_000)] - async fn test_combined_broadcast_filters() { + /// Creates and runs a server. Returns the server, the pipe name, + /// and the JoinHandle of the server.run() task. 
+ async fn run_server() -> (Server, String, JoinHandle<()>) { let server = Server::new(); let test_id: u16 = rand::thread_rng().gen(); let pipe_name = format!(r"\\.\pipe\deucalion-test-{}", test_id); - let pipe_name_clone = pipe_name.clone(); let server_clone = server.clone(); - tokio::spawn(async move { + let pipe_name_clone = pipe_name.clone(); + let server_handle = tokio::spawn(async move { server_clone .run(pipe_name_clone, move |_: rpc::Payload| Ok(())) .await .expect("Server should not fail to run"); }); - // Give the server some time to start tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - let subscriber = Endpoint::connect(&pipe_name) - .await - .expect("Failed to connect subscriber to server"); - - let codec = rpc::PayloadCodec::new(); - let mut frames = Framed::new(subscriber, codec); + (server, pipe_name, server_handle) + } + async fn handle_server_hello(frames: &mut Framed) + where + T: AsyncRead + AsyncWrite + std::marker::Unpin, + { // Handle the SERVER_HELLO message let message = frames.next().await.unwrap(); if let Ok(payload) = message { @@ -490,80 +533,87 @@ mod tests { } else { panic!("Did not properly receive Server Hello"); } + } - let filter = BroadcastFilter::AllowChatRecv as u32 - | BroadcastFilter::AllowChatSend as u32 - | BroadcastFilter::AllowZoneRecv as u32; + #[tokio::test(flavor = "multi_thread")] + #[timeout(10_000)] + async fn test_combined_broadcast_filters() { + let (server, pipe_name, server_handle) = run_server().await; - // Send option - frames - .send(rpc::Payload { - op: rpc::MessageOps::Option, - ctx: filter, - data: Vec::new(), - }) - .await - .unwrap(); + let test_handle = tokio::spawn(async move { + let subscriber = Endpoint::connect(&pipe_name) + .await + .expect("Failed to connect subscriber to server"); - let message = frames.next().await.unwrap(); - if let Ok(payload) = message { - assert_eq!(payload.op, rpc::MessageOps::Debug); - assert_eq!( - String::from_utf8(payload.data).unwrap(), - "Packet 
filters set: 0b00100110", - ); - } else { - panic!("Did not properly receive Server Hello"); - } + let codec = rpc::PayloadCodec::new(); + let mut frames = Framed::new(subscriber, codec); - let configurations = vec![ - (rpc::MessageOps::Recv, 0, false), - (rpc::MessageOps::Recv, 1, true), - (rpc::MessageOps::Recv, 2, true), - (rpc::MessageOps::Send, 0, false), - (rpc::MessageOps::Send, 1, false), - (rpc::MessageOps::Send, 2, true), - (rpc::MessageOps::Recv, 100, false), - ]; + handle_server_hello(&mut frames).await; - for (op, ctx, should_be_allowed) in configurations { - server - .broadcast(rpc::Payload { - op, - ctx, + let filter = BroadcastFilter::AllowChatRecv as u32 + | BroadcastFilter::AllowChatSend as u32 + | BroadcastFilter::AllowZoneRecv as u32; + + // Send option + frames + .send(rpc::Payload { + op: rpc::MessageOps::Option, + ctx: filter, data: Vec::new(), }) - .await; + .await + .unwrap(); + + let message = frames.next().await.unwrap(); + if let Ok(payload) = message { + assert_eq!(payload.op, rpc::MessageOps::Debug); + assert_eq!( + String::from_utf8(payload.data).unwrap(), + "Packet filters set: 0b00100110", + ); + } else { + panic!("Did not properly receive Server Hello"); + } - select! { - data = frames.next() => { - assert_eq!(should_be_allowed, true, "packet should be filtered: {:?}", data) - } - _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => { - assert_eq!(should_be_allowed, false, "packet should not be filtered: {:?}: {}", op, ctx) + let configurations = vec![ + (rpc::MessageOps::Recv, 0, false), + (rpc::MessageOps::Recv, 1, true), + (rpc::MessageOps::Recv, 2, true), + (rpc::MessageOps::Send, 0, false), + (rpc::MessageOps::Send, 1, false), + (rpc::MessageOps::Send, 2, true), + (rpc::MessageOps::Recv, 100, false), + ]; + + for (op, ctx, should_be_allowed) in configurations { + server + .broadcast(rpc::Payload { + op, + ctx, + data: Vec::new(), + }) + .await; + + select! 
{ + data = frames.next() => { + assert_eq!(should_be_allowed, true, "packet should be filtered: {:?}", data); + } + _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => { + assert_eq!(should_be_allowed, false, "packet should not be filtered: {:?}: {}", op, ctx) + } } } - } + }); + + test_handle.await.expect("Test failed with assertion"); + server_handle.abort(); + let _ = server_handle.await; } #[tokio::test(flavor = "multi_thread")] #[timeout(10_000)] async fn test_server_shutdown() { - let server = Server::new(); - - let test_id: u16 = rand::thread_rng().gen(); - let pipe_name = format!(r"\\.\pipe\deucalion-test-{}", test_id); - let pipe_name_clone = pipe_name.clone(); - - let server_task = tokio::spawn(async move { - server - .run(pipe_name_clone, move |_: rpc::Payload| Ok(())) - .await - .expect("Server should not fail to run"); - }); - - // Give the server some time to start - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + let (_, pipe_name, server_handle) = run_server().await; let subscriber = Endpoint::connect(&pipe_name) .await @@ -572,13 +622,7 @@ mod tests { let codec = rpc::PayloadCodec::new(); let mut frames = Framed::new(subscriber, codec); - // Handle the SERVER_HELLO message - let message = frames.next().await.unwrap(); - if let Ok(payload) = message { - assert_eq!(payload.ctx, 9000); - } else { - panic!("Did not properly receive Server Hello"); - } + handle_server_hello(&mut frames).await; // Send exit frames @@ -591,94 +635,78 @@ mod tests { .unwrap(); // Wait on the server to shut down - let _ = server_task.await; + let _ = server_handle.await; } + /// The server is expected to shut down when the last subscriber + /// disconnects. The last subscriber must hold the connection long enough to + /// get the SERVER_HELLO for this to be the case. 
#[tokio::test(flavor = "multi_thread")] #[timeout(10_000)] async fn test_subscriber_disconnect() { - let server = Server::new(); - - let test_id: u16 = rand::thread_rng().gen(); - let pipe_name = format!(r"\\.\pipe\deucalion-test-{}", test_id); - let pipe_name_clone = pipe_name.clone(); - - let server_task = tokio::spawn(async move { - server - .run(pipe_name_clone, move |_: rpc::Payload| Ok(())) - .await - .expect("Server should not fail to run"); - }); - - // Give the server some time to start - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + let (_, pipe_name, server_handle) = run_server().await; let subscriber = Endpoint::connect(&pipe_name) .await .expect("Failed to connect subscriber to server"); - let codec = rpc::PayloadCodec::new(); let mut frames = Framed::new(subscriber, codec); - // Handle the SERVER_HELLO message - let message = frames.next().await.unwrap(); - if let Ok(payload) = message { - assert_eq!(payload.ctx, 9000); - } else { - panic!("Did not properly receive Server Hello"); - } + handle_server_hello(&mut frames).await; + + let subscriber_handle = + tokio::spawn(async move { while let Some(_) = frames.next().await {} }); // Disconnect the subscriber forcefully - drop(frames); + subscriber_handle.abort(); + let _ = subscriber_handle.await; - // Wait on the server to shut down - let _ = server_task.await; + let _ = server_handle.await; } #[tokio::test(flavor = "multi_thread")] #[timeout(10_000)] async fn test_second_subscriber_disconnect() { - let server = Server::new(); + let (server, pipe_name, server_handle) = run_server().await; - let test_id: u16 = rand::thread_rng().gen(); - let pipe_name = format!(r"\\.\pipe\deucalion-test-{}", test_id); - let pipe_name_clone = pipe_name.clone(); - - let server_clone = server.clone(); - tokio::spawn(async move { - server_clone - .run(pipe_name_clone, move |_: rpc::Payload| Ok(())) - .await - .expect("Server should not fail to run"); - }); + let num_received = 
Arc::new(AtomicU32::new(0)); - // Give the server some time to start - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + let num_received_clone = num_received.clone(); + let pipe_name_clone = pipe_name.clone(); - let subscriber = Endpoint::connect(&pipe_name) + let subscriber = Endpoint::connect(&pipe_name_clone) .await .expect("Failed to connect subscriber to server"); let codec = rpc::PayloadCodec::new(); let mut frames = Framed::new(subscriber, codec); - // Handle the SERVER_HELLO message - let message = frames.next().await.unwrap(); - if let Ok(payload) = message { - assert_eq!(payload.ctx, 9000); - } else { - panic!("Did not properly receive Server Hello"); - } + handle_server_hello(&mut frames).await; + + let subscriber_handle = tokio::spawn(async move { + // Test that every packet was received in order + while let Some(result) = frames.next().await { + let payload = result.unwrap(); + let num_received_val = num_received_clone.fetch_add(1, Ordering::SeqCst); + assert_eq!( + payload.ctx, num_received_val, + "Received data from pipe does not match expected index!" + ); + } + }); - // Create and quickly drop the second connection - { - let subscriber2 = Endpoint::connect(&pipe_name) + // Create and quickly drop the second subscriber + let pipe_name_clone = pipe_name.clone(); + let second_subscriber = tokio::spawn(async move { + let subscriber = Endpoint::connect(&pipe_name_clone) .await .expect("Failed to connect subscriber to server"); - let codec2 = rpc::PayloadCodec::new(); - let _frames2 = Framed::new(subscriber2, codec2); - } + let codec = rpc::PayloadCodec::new(); + let mut frames = Framed::new(subscriber, codec); + while let Some(_) = frames.next().await {} + }); + second_subscriber.abort(); // Send two packets for i in 0..2 { @@ -694,24 +722,24 @@ mod tests { .await; } - let mut num_received = 0u32; - // Test that every packet was received in order - loop { - select! 
{ - data = frames.next() => { - let payload = data.expect("Expected data from frames!").unwrap(); - assert_eq!( - payload.ctx, num_received, - "Received data from pipe does not match expected index!" - ); - num_received += 1; - } - _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => { - assert_eq!(num_received, 2, "two packets should be received"); - return; - } + // Give some time for the subscriber to process the messages + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + subscriber_handle.abort(); + if let Err(e) = subscriber_handle.await { + if !e.is_cancelled() { + panic!("Test failed with assertion: {}", e); } } + + let num_received_val = num_received.load(Ordering::SeqCst); + assert_eq!( + num_received_val, 2, + "two packets should be received by the subscriber" + ); + + server_handle.abort(); + let _ = server_handle.await; } /// A test to ensure the named pipe can handle a lot of data sent through @@ -719,59 +747,37 @@ mod tests { #[tokio::test(flavor = "multi_thread")] #[timeout(10_000)] async fn named_pipe_load_test() { - let server = Server::new(); - - let test_id: u16 = rand::thread_rng().gen(); - let pipe_name = format!(r"\\.\pipe\deucalion-test-{}", test_id); - let pipe_name_clone = pipe_name.clone(); + let (server, pipe_name, server_handle) = run_server().await; - let server_clone = server.clone(); - tokio::spawn(async move { - server_clone - .run(pipe_name_clone, move |_: rpc::Payload| Ok(())) + let test_handle = tokio::spawn(async move { + let subscriber = Endpoint::connect(&pipe_name) .await - .expect("Server should not fail to run"); - }); - - // Give the server some time to start - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + .expect("Failed to connect subscriber to server"); - let subscriber = Endpoint::connect(&pipe_name) - .await - .expect("Failed to connect subscriber to server"); + let codec = rpc::PayloadCodec::new(); + let mut frames = Framed::new(subscriber, codec); - let codec 
= rpc::PayloadCodec::new(); - let mut frames = Framed::new(subscriber, codec); + handle_server_hello(&mut frames).await; - // Handle the SERVER_HELLO message - let message = frames.next().await.unwrap(); - if let Ok(payload) = message { - assert_eq!(payload.ctx, 9000); - } else { - panic!("Did not properly receive Server Hello"); - } + // Synchronously send many packets before the subscriber can process them + const NUM_PACKETS: u32 = 10000; + for i in 0..NUM_PACKETS { + let mut dummy_data = Vec::from([0u8; 5000]); + rand::thread_rng().fill(&mut dummy_data[..]); - // Synchronously send many packets before the subscriber can process them - const NUM_PACKETS: u32 = 10000; - for i in 0..NUM_PACKETS { - let mut dummy_data = Vec::from([0u8; 5000]); - rand::thread_rng().fill(&mut dummy_data[..]); - - server - .broadcast(rpc::Payload { - op: rpc::MessageOps::Debug, - ctx: i, - data: dummy_data, - }) - .await; - } + server + .broadcast(rpc::Payload { + op: rpc::MessageOps::Debug, + ctx: i, + data: dummy_data, + }) + .await; + } - // Test that every packet was received in order - let mut num_received = 0u32; - while let Some(result) = frames.next().await { - match result { - // A request was received from the current user - Ok(payload) => { + // Test that every packet was received in order + let mut num_received = 0u32; + while let Some(result) = frames.next().await { + if let Ok(payload) = result { assert_eq!( payload.ctx, num_received, "Received data from pipe does not match expected index!" 
@@ -781,9 +787,13 @@ mod tests { return; } } - _ => (), } - } + }); + + test_handle.await.expect("Test failed with assertion"); + + server_handle.abort(); + let _ = server_handle.await; } /// A test to ensure the server remains stable even when creating and @@ -791,27 +801,20 @@ mod tests { #[tokio::test(flavor = "multi_thread")] #[timeout(10_000)] async fn early_disconnection_stress_test() { - let server = Server::new(); - - let test_id: u16 = rand::thread_rng().gen(); - let pipe_name = format!(r"\\.\pipe\deucalion-test-{}", test_id); - let pipe_name_clone = pipe_name.clone(); - - let server_clone = server.clone(); - tokio::spawn(async move { - server_clone - .run(pipe_name_clone, move |_: rpc::Payload| Ok(())) - .await - .expect("Server should not fail to run"); - }); + let (server, pipe_name, server_handle) = run_server().await; // Create and quickly drop these connections for _ in 0..100 { - // If the subscriber couldn't connect it's okay - if let Ok(subscriber) = Endpoint::connect(&pipe_name).await { - let codec = rpc::PayloadCodec::new(); - let _frames = Framed::new(subscriber, codec); - } + let pipe_name_clone = pipe_name.clone(); + let sub_handle = tokio::spawn(async move { + // If the subscriber couldn't connect it's okay + if let Ok(subscriber) = Endpoint::connect(&pipe_name_clone).await { + let codec = rpc::PayloadCodec::new(); + let mut frames = Framed::new(subscriber, codec); + while let Some(_) = frames.next().await {} + } + }); + sub_handle.abort(); } let subscriber = Endpoint::connect(&pipe_name) @@ -821,45 +824,50 @@ mod tests { let codec = rpc::PayloadCodec::new(); let mut frames = Framed::new(subscriber, codec); - // Handle the SERVER_HELLO message - let message = frames.next().await.unwrap(); - if let Ok(payload) = message { - assert_eq!(payload.ctx, 9000); - } else { - panic!("Did not properly receive Server Hello"); - } + handle_server_hello(&mut frames).await; + + let num_received = Arc::new(AtomicU32::new(0)); + let num_received_clone = 
num_received.clone(); + let subscriber_handle = tokio::spawn(async move { + // Test that every packet was received in order + while let Some(result) = frames.next().await { + let payload = result.unwrap(); + let num_received_val = num_received_clone.fetch_add(1, Ordering::SeqCst); + assert_eq!( + payload.ctx, num_received_val, + "Received data from pipe does not match expected index!" + ); + } + }); // Send two packets for i in 0..2 { - let mut dummy_data = Vec::from([0u8; 5000]); - rand::thread_rng().fill(&mut dummy_data[..]); - server .broadcast(rpc::Payload { op: rpc::MessageOps::Debug, ctx: i, - data: dummy_data, + data: Vec::new(), }) .await; } - let mut num_received = 0u32; - // Test that every packet was received in order - loop { - select! { - data = frames.next() => { - let payload = data.expect("Expected data from frames!").unwrap(); - assert_eq!( - payload.ctx, num_received, - "Received data from pipe does not match expected index!" - ); - num_received += 1; - } - _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => { - assert_eq!(num_received, 2, "two packets should be received"); - return; - } + // Give some time for the subscriber to process the messages + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + subscriber_handle.abort(); + if let Err(e) = subscriber_handle.await { + if !e.is_cancelled() { + panic!("Test failed with assertion: {}", e); } } + + let num_received_val = num_received.load(Ordering::SeqCst); + assert_eq!( + num_received_val, 2, + "two packets should be received by the subscriber" + ); + + server_handle.abort(); + let _ = server_handle.await; } }