diff --git a/README.md b/README.md index 7505bfb..bd0717b 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,9 @@ types: | 1 | Zone | | 2 | Chat | + +For the `Debug` OP, the CHANNEL 9000 indicates the `HELLO` CHANNEL. + ### Data For payloads with OP `Debug`, the payload is simply debug-logged. @@ -117,8 +120,25 @@ CHANNEL. ### Debug OP -Any payload sent with the `Debug` OP will be simply be debug-logged and an `OK` -response will be sent back to the requesting subscriber. +Any payload sent with the `Debug` OP with a CHANNEL other than `HELLO` will be +simply be debug-logged and an `OK` response will be sent back to the requesting +subscriber. + +For payloads with the `Debug` OP and the `HELLO` channel, the DATA sent will +be set as the nickname for the current subscriber. + +These are the restrictions on the values of nickname that are accepted: +- The bytes in DATA must be valid UTF-8. +- Must contain only ASCII alphanumeric characters or underscores (`^[A-Za-z0-9_]+$`) +- Must be 30 characters or less +- Names are allowed to be the same as other subscribers. + +For example: +```c +Payload { OP: OP.Debug, CHANNEL: 9000, DATA: u8"TEST_CLIENT" } +// Deucalion: Nickname set response +Payload { OP: OP.Debug, CHANNEL: 9000, DATA: u8"CHANGED NICKNAME: TEST_CLIENT (subscriber 0)" } +``` ### Ping OP diff --git a/src/lib.rs b/src/lib.rs index 3a0b5de..2370a09 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,7 +43,6 @@ const SEND_SIG: &str = "E8 $ { ' } 8B 53 2C 48 8D 8B"; const SEND_LOBBY_SIG: &str = "40 53 48 83 EC 20 44 8B 41 28"; fn handle_payload(payload: rpc::Payload, hs: Arc) -> Result<()> { - info!("Received payload from subscriber: {:?}", payload); if payload.op == rpc::MessageOps::Recv || payload.op == rpc::MessageOps::Send { let hook_type = match payload.op { rpc::MessageOps::Recv => hook::HookType::Recv, @@ -148,7 +147,7 @@ async fn main_with_result() -> Result<()> { drop(shutdown_tx); info!("Shutting down broadcast loop..."); msg_loop_handle.await?; - info!("Shut down!"); + info!("Shutting down..."); Ok(()) } @@ -189,7 +188,7 @@ fn logging_setup() -> Result<()> { #[cfg(debug_assertions)] { - let _ = CombinedLogger::init(vec![ + CombinedLogger::init(vec![ SimpleLogger::new(LevelFilter::Debug, simplelog::Config::default()), WriteLogger::new(LevelFilter::Debug, simplelog::Config::default(), log_file), ])?; @@ -220,6 +219,7 @@ unsafe extern "system" fn main(dll_base_addr: LPVOID) -> u32 { error!("Panic happened: {:?}", cause); pause(); } + info!("Shut down!"); #[cfg(debug_assertions)] wincon::FreeConsole(); libloaderapi::FreeLibraryAndExitThread(dll_base_addr as HMODULE, 0); diff --git a/src/server.rs b/src/server.rs index 446520d..5bc656b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -89,6 +89,110 @@ where } } +const HELLO_CHANNEL: u32 = 9000; + +fn dbg_payload(ctx: u32, data: Vec) -> Payload { + let op = MessageOps::Debug; + Payload { op, ctx, data } +} + +fn ping_payload() -> Payload { + Payload { + op: MessageOps::Ping, + ctx: 0, + data: Vec::new(), + } +} + +/// Checks to make sure that the UTF-8 string is 30 characters or less and is +/// ASCII alphanumeric with underscores allowwed +fn validate_nickname(nickname: &String) -> Result<()> { + if nickname.len() > 30 { + return Err(format_err!("Nickname exceeds 30 chars: {nickname:?}").into()); + } + if !nickname + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_') + { + return Err(format_err!("Nickname contains invalid characters: {nickname:?}").into()); + } + Ok(()) +} + +impl Subscriber +where + T: AsyncRead + AsyncWrite + std::marker::Unpin, +{ + async fn send_dbg_payload(&mut self, ctx: u32, data: Vec) -> Result<()> { + self.frames.send(dbg_payload(ctx, data)).await?; + Ok(()) + } + + /// Handle the nickname sent from the subscriber and send a success/failure + /// response back. + async fn handle_nickname(&mut self, payload: Payload, nickname: &mut String) -> Result<()> + where + T: AsyncRead + AsyncWrite + std::marker::Unpin, + { + let bytes_str = format!("{:?}", payload.data); + match String::from_utf8(payload.data) { + Ok(nickname_str) => { + if let Err(e) = validate_nickname(&nickname_str) { + self.send_dbg_payload( + HELLO_CHANNEL, + format!("INVALID NICKNAME: \"{nickname_str}\"").into(), + ) + .await?; + return Err(e); + } + *nickname = format!("{nickname_str} (subscriber {})", self.id); + self.send_dbg_payload( + HELLO_CHANNEL, + format!("CHANGED NICKNAME: {nickname}").into(), + ) + .await?; + } + Err(e) => { + self.send_dbg_payload( + HELLO_CHANNEL, + format!("INVALID NICKNAME: {bytes_str}").into(), + ) + .await?; + return Err(e.into()); + } + } + Ok(()) + } + + /// Handle the payload from subscriber and send a success/failure response back + async fn handle_payload(&mut self, payload: Payload, payload_handler: &F) -> Result<()> + where + F: Fn(Payload) -> Result<()>, + { + let ctx = payload.ctx; + + let ack_prefix = { + match payload.op { + MessageOps::Recv => "RECV ", + MessageOps::Send => "SEND ", + _ => "", + } + }; + + match payload_handler(payload) { + Ok(()) => { + self.send_dbg_payload(ctx, format!("{ack_prefix}OK").into()) + .await? + } + Err(e) => { + self.send_dbg_payload(ctx, format!("{ack_prefix}{e}").into()) + .await? + } + } + Ok(()) + } +} + #[repr(u32)] enum BroadcastFilter { AllowLobbyRecv = 1, @@ -207,70 +311,20 @@ impl Server { } } - /// Handle the payload from subscriber and send a success/failure response back - async fn handle_payload_from_subscriber( - payload: Payload, - subscriber: &mut Subscriber, - payload_handler: &F, - ) -> Result<()> - where - T: AsyncRead + AsyncWrite + std::marker::Unpin, - F: Fn(Payload) -> Result<()>, - { - let ctx = payload.ctx; - - let ack_prefix = { - match payload.op { - MessageOps::Recv => "RECV ", - MessageOps::Send => "SEND ", - _ => "", - } - }; - - match payload_handler(payload) { - Ok(()) => { - subscriber - .frames - .send(Payload { - op: MessageOps::Debug, - ctx, - data: format!("{}OK", ack_prefix).into_bytes(), - }) - .await? - } - Err(e) => { - subscriber - .frames - .send(Payload { - op: MessageOps::Debug, - ctx, - data: format!("{}{}", ack_prefix, e).into_bytes(), - }) - .await? - } - } - Ok(()) - } - /// Subscriber message loop. Returns an error if it has trouble writing /// to the connection. /// Returns true if it returned because of an Exit request. Returns false /// if the connection was closed naturally. async fn subscriber_msg_loop( &self, - mut subscriber: &mut Subscriber, + subscriber: &mut Subscriber, + nickname: &mut String, payload_handler: F, ) -> Result where T: AsyncRead + AsyncWrite + std::marker::Unpin, F: Fn(Payload) -> Result<()>, { - let ping_payload = Payload { - op: MessageOps::Ping, - ctx: 0, - data: Vec::new(), - }; - // Default packet filter is AllowZoneRecv only let mut filter: u32 = BroadcastFilter::AllowZoneRecv as u32; @@ -280,7 +334,7 @@ impl Server { // A request was received from the current subscriber Ok(Message::Request(payload)) => match payload.op { MessageOps::Ping => { - subscriber.frames.send(ping_payload.clone()).await?; + subscriber.frames.send(ping_payload()).await?; } MessageOps::Exit => { info!("Shutting down server because Exit payload received"); @@ -290,25 +344,26 @@ impl Server { MessageOps::Option => { filter = payload.ctx; subscriber - .frames - .send(Payload { - op: MessageOps::Debug, - ctx: 0, - data: format!("Packet filters set: {filter:#010b}").into_bytes(), - }) + .send_dbg_payload( + 0, + format!("Packet filters set: {filter:#010b}").into(), + ) .await?; - info!( - "Filter set for subscriber {}: {filter:#010b}", - subscriber.id - ); + info!("Filter set for {nickname}: {filter:#010b}"); + } + MessageOps::Debug if payload.ctx == HELLO_CHANNEL => { + if let Err(e) = subscriber.handle_nickname(payload, nickname).await { + error!("Error setting nickname for {nickname}: {e}") + } else { + info!( + "Changed nickname for subscriber {} to {nickname}", + subscriber.id + ) + } } _ => { - Self::handle_payload_from_subscriber( - payload, - &mut subscriber, - &payload_handler, - ) - .await?; + info!("Received payload from {nickname}: {:?}", payload); + subscriber.handle_payload(payload, &payload_handler).await?; } }, @@ -320,8 +375,7 @@ impl Server { } Err(e) => { error!( - "An error occured while processing messages for subscriber {}; error = {:?}", - subscriber.id, e + "An error occured while processing messages for {nickname}; error = {e}", ); } } @@ -346,23 +400,17 @@ impl Server { let hello_string = self.state.lock().await.server_hello_string(); let mut subscriber = Subscriber { id, frames, rx }; + let mut nickname = format!("subscriber {}", subscriber.id); subscriber - .frames - .send( - Payload { - op: MessageOps::Debug, - ctx: 9000, - data: hello_string.into_bytes(), - } - .into(), - ) - .await?; + .send_dbg_payload(HELLO_CHANNEL, hello_string.into()) + .await + .map_err(|e| format_err!("Could not send SERVER HELLO to {nickname}: {e}"))?; - info!("New subscriber connected: {}", subscriber.id); + info!("New subscriber connected: {nickname}"); match self - .subscriber_msg_loop(&mut subscriber, payload_handler) + .subscriber_msg_loop(&mut subscriber, &mut nickname, payload_handler) .await { Ok(server_exit) => { @@ -370,16 +418,13 @@ impl Server { return Ok(()); } } - Err(e) => error!( - "Disconecting subscriber {} because of error: {}", - subscriber.id, e - ), + Err(e) => error!("Disconecting {} because of error: {}", nickname, e), } // If this section is reached it means that the subscriber was // disconnected one way or another. { - info!("Subscriber disconnected: {}", subscriber.id); + info!("Disconnected: {}", nickname); let mut state = self.state.lock().await; state.subscribers.remove(&subscriber.id); // Exit once all subscribers are disconnected @@ -500,6 +545,43 @@ mod tests { } } + #[test] + fn test_nickname_validation() { + let nickname_tests = [ + ("Inquisitor1234", None), + ("Names_with_underscores", None), + ( + "Names with spaces", + Some("Nickname contains invalid characters: \"Names with spaces\""), + ), + ( + "Names.with.punctuation.marks", + Some("Nickname contains invalid characters: \"Names.with.punctuation.marks\""), + ), + ( + "ASCIIではありません", + Some("Nickname contains invalid characters: \"ASCIIではありません\""), + ), + ( + "This_name_is_over_30_chars_long", + Some("Nickname exceeds 30 chars: \"This_name_is_over_30_chars_long\""), + ), + ]; + for (nickname, expected_err) in nickname_tests { + match validate_nickname(&nickname.to_string()) { + Ok(()) => { + if let Some(err_msg) = expected_err { + panic!("Expected validation for {nickname} to error with {err_msg}"); + } + } + Err(e) => match expected_err { + Some(msg) => assert_eq!(e.to_string(), msg), + None => panic!("Expected validation for {nickname} to succeed but got err {e}"), + }, + } + } + } + #[tokio::test(flavor = "multi_thread")] #[timeout(10_000)] async fn test_server_hello_message() { @@ -561,7 +643,7 @@ mod tests { // Handle the SERVER_HELLO message let message = frames.next().await.unwrap(); if let Ok(payload) = message { - assert_eq!(payload.ctx, 9000); + assert_eq!(payload.ctx, HELLO_CHANNEL); } else { panic!("Did not properly receive Server Hello"); } @@ -604,7 +686,7 @@ mod tests { "Packet filters set: 0b00100110", ); } else { - panic!("Did not properly receive Server Hello"); + panic!("Did not properly receive packet filter confirmation"); } let configurations = vec![ @@ -642,6 +724,82 @@ mod tests { let _ = server_handle.await; } + #[tokio::test(flavor = "multi_thread")] + #[timeout(10_000)] + async fn test_subscriber_nickname() { + let (_, pipe_name, server_handle) = run_server().await; + + struct Testcase { + nickname: Vec, + expected_resp: &'static str, + } + let testcases: Vec<(Vec, &str)> = vec![ + ( + "Inquisitor1234".into(), + "CHANGED NICKNAME: Inquisitor1234 (subscriber 0)", + ), + ( + "Names_with_underscores".into(), + "CHANGED NICKNAME: Names_with_underscores (subscriber 0)", + ), + ( + // Invalid UTF-8 string not allowed + vec![0, 150, 200, 250], + "INVALID NICKNAME: [0, 150, 200, 250]", + ), + ( + // Names with spaces not allowed + "Names with spaces".into(), + "INVALID NICKNAME: \"Names with spaces\"", + ), + ( + // Names with other symbols not allowed + "Names.with.punctuation.marks".into(), + "INVALID NICKNAME: \"Names.with.punctuation.marks\"", + ), + ( + // Names over 30 characters not allowed + "This_name_is_over_30_chars_long".into(), + "INVALID NICKNAME: \"This_name_is_over_30_chars_long\"", + ), + ]; + let test_handle = tokio::spawn(async move { + let subscriber = Endpoint::connect(&pipe_name) + .await + .expect("Failed to connect subscriber to server"); + + let codec = PayloadCodec::new(); + let mut frames = Framed::new(subscriber, codec); + + handle_server_hello(&mut frames).await; + + for (nickname, expected_resp) in testcases { + frames + .send(dbg_payload(HELLO_CHANNEL, nickname)) + .await + .unwrap(); + + let message = frames.next().await.unwrap(); + if let Ok(payload) = message { + assert_eq!(payload.op, MessageOps::Debug); + assert_eq!( + String::from_utf8(payload.data).unwrap(), + expected_resp, + "Expected response did not match" + ); + } else { + panic!("Did not receive subscriber nickname confirmation"); + } + } + }); + + test_handle.await.expect("Test failed with assertion"); + server_handle.abort(); + + // Wait on the server to shut down + let _ = server_handle.await; + } + #[tokio::test(flavor = "multi_thread")] #[timeout(10_000)] async fn test_server_shutdown() { @@ -745,13 +903,7 @@ mod tests { let mut dummy_data = Vec::from([0u8; 5000]); rand::thread_rng().fill(&mut dummy_data[..]); - server - .broadcast(Payload { - op: MessageOps::Debug, - ctx: i, - data: dummy_data, - }) - .await; + server.broadcast(dbg_payload(i, dummy_data)).await; } // Give some time for the subscriber to process the messages @@ -797,13 +949,7 @@ mod tests { let mut dummy_data = Vec::from([0u8; 5000]); rand::thread_rng().fill(&mut dummy_data[..]); - server - .broadcast(Payload { - op: MessageOps::Debug, - ctx: i, - data: dummy_data, - }) - .await; + server.broadcast(dbg_payload(i, dummy_data)).await; } // Test that every packet was received in order @@ -874,13 +1020,7 @@ mod tests { // Send two packets for i in 0..2 { - server - .broadcast(Payload { - op: MessageOps::Debug, - ctx: i, - data: Vec::new(), - }) - .await; + server.broadcast(dbg_payload(i, Vec::new())).await; } // Give some time for the subscriber to process the messages