diff --git a/Cargo.lock b/Cargo.lock index 121eadb2..014334bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -373,6 +373,7 @@ dependencies = [ "tokio-tungstenite 0.20.0", "tower", "tracing", + "tracing-subscriber", "unicode-segmentation", ] @@ -1263,11 +1264,13 @@ dependencies = [ "futures", "http", "http-body", + "hyper", "itertools 0.11.0", "serde", "serde_json", "thiserror", "tokio", + "tokio-tungstenite 0.20.0", "tower", "tower-http", "tracing", diff --git a/e2e/src/engineioxide.rs b/e2e/src/engineioxide.rs index ec206545..620c090b 100644 --- a/e2e/src/engineioxide.rs +++ b/e2e/src/engineioxide.rs @@ -3,7 +3,10 @@ use std::time::Duration; use engineioxide::{ - config::EngineIoConfig, handler::EngineIoHandler, service::EngineIoService, socket::Socket, + config::EngineIoConfig, + handler::EngineIoHandler, + service::EngineIoService, + socket::{DisconnectReason, Socket}, }; use hyper::Server; use tracing::{info, Level}; @@ -19,8 +22,8 @@ impl EngineIoHandler for MyHandler { fn on_connect(&self, socket: &Socket) { println!("socket connect {}", socket.sid); } - fn on_disconnect(&self, socket: &Socket) { - println!("socket disconnect {}", socket.sid); + fn on_disconnect(&self, socket: &Socket, reason: DisconnectReason) { + println!("socket disconnect {}: {:?}", socket.sid, reason); } fn on_message(&self, msg: String, socket: &Socket) { diff --git a/engineioxide/Cargo.toml b/engineioxide/Cargo.toml index d10e0569..a270e0f4 100644 --- a/engineioxide/Cargo.toml +++ b/engineioxide/Cargo.toml @@ -45,7 +45,16 @@ unicode-segmentation = { version = "1.10.1", optional = true } [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] } -tokio = { version = "1.26.0", features = ["macros"] } +tokio = { version = "1.26.0", features = ["macros", "parking_lot"] } +tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } +hyper = { version = "0.14.25", features = [ + "http1", + "http2", + "server", + "stream", + "runtime", + "client", +] } [features] default = ["v4"] diff --git a/engineioxide/benches/benchmark_polling.rs b/engineioxide/benches/benchmark_polling.rs index 4626d481..03b17080 100644 --- a/engineioxide/benches/benchmark_polling.rs +++ b/engineioxide/benches/benchmark_polling.rs @@ -3,6 +3,7 @@ use std::time::Duration; use bytes::{Buf, Bytes}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use engineioxide::socket::DisconnectReason; use engineioxide::{handler::EngineIoHandler, service::EngineIoService, socket::Socket}; use engineioxide::sid_generator::Sid; @@ -29,7 +30,7 @@ impl EngineIoHandler for Client { fn on_connect(&self, _: &Socket) {} - fn on_disconnect(&self, _: &Socket) {} + fn on_disconnect(&self, _: &Socket, _reason: DisconnectReason) {} fn on_message(&self, msg: String, socket: &Socket) { socket.emit(msg).unwrap(); diff --git a/engineioxide/src/config.rs b/engineioxide/src/config.rs index 03982ea9..5898c037 100644 --- a/engineioxide/src/config.rs +++ b/engineioxide/src/config.rs @@ -95,7 +95,7 @@ impl EngineIoConfigBuilder { /// # use engineioxide::{ /// layer::EngineIoLayer, /// handler::EngineIoHandler, - /// socket::Socket, + /// socket::{Socket, DisconnectReason}, /// }; /// # use std::sync::Arc; /// #[derive(Debug, Clone)] @@ -108,7 +108,7 @@ impl EngineIoConfigBuilder { /// fn on_connect(&self, socket: &Socket) { /// println!("socket connect {}", socket.sid); /// } - /// fn on_disconnect(&self, socket: &Socket) { + /// fn on_disconnect(&self, socket: &Socket, reason: DisconnectReason) { /// println!("socket disconnect {}", socket.sid); /// } /// diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index fb331257..b75f1a87 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -14,7 +14,7 @@ use crate::{ payload::{self}, service::TransportType, sid_generator::generate_sid, - socket::{ConnectionType, Socket, SocketReq}, + socket::{ConnectionType, DisconnectReason, Socket, SocketReq}, }; use crate::{service::ProtocolVersion, sid_generator::Sid}; use futures::{stream::SplitStream, SinkExt, StreamExt, TryStreamExt}; @@ -62,7 +62,8 @@ impl EngineIo { B: Send + 'static, { let engine = self.clone(); - let close_fn = Box::new(move |sid: Sid| engine.close_session(sid)); + let close_fn = + Box::new(move |sid: Sid, reason: DisconnectReason| engine.close_session(sid, reason)); let sid = generate_sid(); let socket = Socket::new( sid, @@ -124,7 +125,7 @@ impl EngineIo { let rx = match socket.internal_rx.try_lock() { Ok(s) => s, Err(_) => { - socket.close(); + socket.close(DisconnectReason::MultipleHttpPollingError); return Err(Error::HttpErrorResponse(StatusCode::BAD_REQUEST)); } }; @@ -168,7 +169,7 @@ impl EngineIo { Ok(Packet::Close) => { debug!("[sid={sid}] closing session"); socket.send(Packet::Noop)?; - self.close_session(sid); + self.close_session(sid, DisconnectReason::TransportClose); break; } Ok(Packet::Pong) | Ok(Packet::Ping) => socket @@ -189,7 +190,7 @@ impl EngineIo { } Err(e) => { debug!("[sid={sid}] error parsing packet: {:?}", e); - self.close_session(sid); + self.close_session(sid, DisconnectReason::PacketParsingError); return Err(e); } }?; @@ -256,7 +257,9 @@ impl EngineIo { } else { let sid = generate_sid(); let engine = self.clone(); - let close_fn = Box::new(move |sid: Sid| engine.close_session(sid)); + let close_fn = Box::new(move |sid: Sid, reason: DisconnectReason| { + engine.close_session(sid, reason) + }); let socket = Socket::new( sid, protocol, @@ -305,10 +308,14 @@ impl EngineIo { }); self.handler.on_connect(&socket); - if let Err(e) = self.ws_forward_to_handler(rx, &socket).await { + if let Err(ref e) = self.ws_forward_to_handler(rx, &socket).await { debug!("[sid={}] error when handling packet: {:?}", socket.sid, e); + if let Some(reason) = e.into() { + self.close_session(socket.sid, reason); + } + } else { + self.close_session(socket.sid, DisconnectReason::TransportClose); } - self.close_session(socket.sid); rx_handle.abort(); Ok(()) } @@ -319,13 +326,12 @@ impl EngineIo { mut rx: SplitStream>, socket: &Arc>, ) -> Result<(), Error> { - while let Ok(msg) = rx.try_next().await { - let Some(msg) = msg else { continue }; + while let Some(msg) = rx.try_next().await? { match msg { Message::Text(msg) => match Packet::try_from(msg)? { Packet::Close => { debug!("[sid={}] closing session", socket.sid); - self.close_session(socket.sid); + self.close_session(socket.sid, DisconnectReason::TransportClose); break; } Packet::Pong | Packet::Ping => socket @@ -448,10 +454,10 @@ impl EngineIo { /// Close an engine.io session by removing the socket from the socket map and closing the socket /// It should be the only way to close a session and to remove a socket from the socket map - fn close_session(&self, sid: Sid) { + fn close_session(&self, sid: Sid, reason: DisconnectReason) { let socket = self.sockets.write().unwrap().remove(&sid); if let Some(socket) = socket { - self.handler.on_disconnect(&socket); + self.handler.on_disconnect(&socket, reason); socket.abort_heartbeat(); debug!( "remaining sockets: {:?}", @@ -486,8 +492,8 @@ mod tests { println!("socket connect {}", socket.sid); } - fn on_disconnect(&self, socket: &Socket) { - println!("socket disconnect {}", socket.sid); + fn on_disconnect(&self, socket: &Socket, reason: DisconnectReason) { + println!("socket disconnect {} {:?}", socket.sid, reason); } fn on_message(&self, msg: String, socket: &Socket) { diff --git a/engineioxide/src/errors.rs b/engineioxide/src/errors.rs index c4218e50..ff66b036 100644 --- a/engineioxide/src/errors.rs +++ b/engineioxide/src/errors.rs @@ -20,8 +20,6 @@ pub enum Error { BadPacket(Packet), #[error("ws transport error: {0:?}")] WsTransport(#[from] tungstenite::Error), - #[error("http transport error: {0:?}")] - HttpTransport(#[from] hyper::Error), #[error("http error: {0:?}")] Http(#[from] http::Error), #[error("internal channel error: {0:?}")] diff --git a/engineioxide/src/handler.rs b/engineioxide/src/handler.rs index 3e0a698e..36fb8033 100644 --- a/engineioxide/src/handler.rs +++ b/engineioxide/src/handler.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; -use crate::socket::Socket; +use crate::socket::{DisconnectReason, Socket}; /// An handler for engine.io events for each sockets. #[async_trait] @@ -12,7 +12,7 @@ pub trait EngineIoHandler: std::fmt::Debug + Send + Sync + Clone + 'static { fn on_connect(&self, socket: &Socket); /// Called when a socket is disconnected. - fn on_disconnect(&self, socket: &Socket); + fn on_disconnect(&self, socket: &Socket, reason: DisconnectReason); /// Called when a message is received from the client. fn on_message(&self, msg: String, socket: &Socket); diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs index 102656a5..f4d6a4c6 100644 --- a/engineioxide/src/service.rs +++ b/engineioxide/src/service.rs @@ -216,6 +216,22 @@ impl FromStr for TransportType { } } } +impl From for &'static str { + fn from(t: TransportType) -> Self { + match t { + TransportType::Polling => "polling", + TransportType::Websocket => "websocket", + } + } +} +impl From for String { + fn from(t: TransportType) -> Self { + match t { + TransportType::Polling => "polling".into(), + TransportType::Websocket => "websocket".into(), + } + } +} #[derive(Debug, Copy, Clone, PartialEq)] pub enum ProtocolVersion { diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index 6252c286..03a61f5a 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -11,6 +11,7 @@ use tokio::{ sync::{mpsc, mpsc::Receiver, Mutex}, task::JoinHandle, }; +use tokio_tungstenite::tungstenite; use tracing::debug; use crate::sid_generator::Sid; @@ -54,6 +55,39 @@ impl From for SocketReq { } } +/// A [`DisconnectReason`] represents the reason why a [`Socket`] was closed. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum DisconnectReason { + /// The client gracefully closed the connection + TransportClose, + /// The client sent multiple polling requests at the same time (it is forbidden according to the engine.io protocol) + MultipleHttpPollingError, + /// The client sent a bad request / the packet could not be parsed correctly + PacketParsingError, + /// An error occured in the transport layer + /// (e.g. the client closed the connection without sending a close packet) + TransportError, + /// The client did not respond to the heartbeat + HeartbeatTimeout, +} + +/// Convert an [`Error`] to a [`DisconnectReason`] if possible +/// This is used to notify the [`Handler`](crate::handler::EngineIoHandler) of the reason why a [`Socket`] was closed +/// If the error cannot be converted to a [`DisconnectReason`] it means that the error was not fatal and the [`Socket`] can be kept alive +impl From<&Error> for Option { + fn from(err: &Error) -> Self { + use Error::*; + match err { + WsTransport(tungstenite::Error::ConnectionClosed) => None, + WsTransport(_) | Io(_) => Some(DisconnectReason::TransportError), + BadPacket(_) | Serialize(_) | Base64(_) | StrUtf8(_) | PayloadTooLarge + | InvalidPacketLength => Some(DisconnectReason::PacketParsingError), + HeartbeatTimeout => Some(DisconnectReason::HeartbeatTimeout), + _ => None, + } + } +} + /// A [`Socket`] represents a connection to the server. /// It is agnostic to the [`TransportType`](crate::service::TransportType). /// It handles : @@ -99,7 +133,7 @@ where heartbeat_handle: Mutex>>, /// Function to call when the socket is closed - close_fn: Box, + close_fn: Box, /// User data bound to the socket pub data: H::Data, @@ -121,7 +155,7 @@ where conn: ConnectionType, config: &EngineIoConfig, req_data: SocketReq, - close_fn: Box, + close_fn: Box, #[cfg(feature = "v3")] supports_binary: bool, ) -> Self { let (internal_tx, internal_rx) = mpsc::channel(config.max_buffer_size); @@ -174,7 +208,7 @@ where let handle = tokio::spawn(async move { if let Err(e) = socket.heartbeat_job(interval, timeout).await { - socket.close(); + socket.close(DisconnectReason::HeartbeatTimeout); debug!("[sid={}] heartbeat error: {:?}", socket.sid, e); } }); @@ -300,8 +334,8 @@ where /// Immediately closes the socket and the underlying connection. /// The socket will be removed from the `Engine` and the [`Handler`](crate::handler::EngineIoHandler) will be notified. - pub fn close(&self) { - (self.close_fn)(self.sid); + pub fn close(&self, reason: DisconnectReason) { + (self.close_fn)(self.sid, reason); self.send(Packet::Close).ok(); } @@ -325,7 +359,10 @@ where #[cfg(test)] impl Socket { - pub fn new_dummy(sid: Sid, close_fn: Box) -> Socket { + pub fn new_dummy( + sid: Sid, + close_fn: Box, + ) -> Socket { let (internal_tx, internal_rx) = mpsc::channel(200); let (tx, rx) = mpsc::channel(200); let (heartbeat_tx, heartbeat_rx) = mpsc::channel(1); diff --git a/engineioxide/tests/disconnect_reason.rs b/engineioxide/tests/disconnect_reason.rs new file mode 100644 index 00000000..0fceaa2a --- /dev/null +++ b/engineioxide/tests/disconnect_reason.rs @@ -0,0 +1,184 @@ +//! Tests for disconnect reasons +//! Test are made on polling and websocket transports: +//! * Heartbeat timeout +//! * Transport close +//! * Multiple http polling +//! * Packet parsing + +use std::time::Duration; + +use engineioxide::{ + handler::EngineIoHandler, + socket::{DisconnectReason, Socket}, +}; +use futures::SinkExt; +use tokio::sync::mpsc; + +mod fixture; + +use fixture::{create_server, send_req}; +use tokio_tungstenite::tungstenite::Message; + +use crate::fixture::{create_polling_connection, create_ws_connection}; + +#[derive(Debug, Clone)] +struct MyHandler { + disconnect_tx: mpsc::Sender, +} + +#[engineioxide::async_trait] +impl EngineIoHandler for MyHandler { + type Data = (); + + fn on_connect(&self, socket: &Socket) { + println!("socket connect {}", socket.sid); + } + fn on_disconnect(&self, socket: &Socket, reason: DisconnectReason) { + println!("socket disconnect {}: {:?}", socket.sid, reason); + self.disconnect_tx.try_send(reason).unwrap(); + } + + fn on_message(&self, msg: String, socket: &Socket) { + println!("Ping pong message {:?}", msg); + socket.emit(msg).ok(); + } + + fn on_binary(&self, data: Vec, socket: &Socket) { + println!("Ping pong binary message {:?}", data); + socket.emit_binary(data).ok(); + } +} + +#[tokio::test] +pub async fn polling_heartbeat_timeout() { + let (disconnect_tx, mut rx) = mpsc::channel(10); + create_server(MyHandler { disconnect_tx }, 1234); + create_polling_connection(1234).await; + + let data = tokio::time::timeout(Duration::from_millis(500), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::HeartbeatTimeout") + .unwrap(); + + assert_eq!(data, DisconnectReason::HeartbeatTimeout); +} + +#[tokio::test] +pub async fn ws_heartbeat_timeout() { + let (disconnect_tx, mut rx) = mpsc::channel(10); + create_server(MyHandler { disconnect_tx }, 12344); + let _stream = create_ws_connection(12344).await; + + let data = tokio::time::timeout(Duration::from_millis(500), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::HeartbeatTimeout") + .unwrap(); + + assert_eq!(data, DisconnectReason::HeartbeatTimeout); +} + +#[tokio::test] +pub async fn polling_transport_closed() { + let (disconnect_tx, mut rx) = mpsc::channel(10); + create_server(MyHandler { disconnect_tx }, 1235); + let sid = create_polling_connection(1235).await; + + send_req( + 1235, + format!("transport=polling&sid={sid}"), + http::Method::POST, + Some("1".into()), + ) + .await; + + let data = tokio::time::timeout(Duration::from_millis(1), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::TransportClose") + .unwrap(); + + assert_eq!(data, DisconnectReason::TransportClose); +} + +#[tokio::test] +pub async fn ws_transport_closed() { + let (disconnect_tx, mut rx) = mpsc::channel(10); + create_server(MyHandler { disconnect_tx }, 12345); + let mut stream = create_ws_connection(12345).await; + + stream.send(Message::Text("1".into())).await.unwrap(); + + let data = tokio::time::timeout(Duration::from_millis(1), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::TransportClose") + .unwrap(); + + assert_eq!(data, DisconnectReason::TransportClose); +} + +#[tokio::test] +pub async fn multiple_http_polling() { + let (disconnect_tx, mut rx) = mpsc::channel(10); + create_server(MyHandler { disconnect_tx }, 1236); + let sid = create_polling_connection(1236).await; + + tokio::spawn(futures::future::join_all(vec![ + send_req( + 1236, + format!("transport=polling&sid={sid}"), + http::Method::GET, + None, + ), + send_req( + 1236, + format!("transport=polling&sid={sid}"), + http::Method::GET, + None, + ), + ])); + + let data = tokio::time::timeout(Duration::from_millis(10), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::DisconnectError::MultipleHttpPolling") + .unwrap(); + + assert_eq!(data, DisconnectReason::MultipleHttpPollingError); +} + +#[tokio::test] +pub async fn polling_packet_parsing() { + let (disconnect_tx, mut rx) = mpsc::channel(10); + create_server(MyHandler { disconnect_tx }, 1237); + let sid = create_polling_connection(1237).await; + send_req( + 1237, + format!("transport=polling&sid={sid}"), + http::Method::POST, + Some("aizdunazidaubdiz".into()), + ) + .await; + + let data = tokio::time::timeout(Duration::from_millis(1), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::PacketParsingError") + .unwrap(); + + assert_eq!(data, DisconnectReason::PacketParsingError); +} + +#[tokio::test] +pub async fn ws_packet_parsing() { + let (disconnect_tx, mut rx) = mpsc::channel(10); + create_server(MyHandler { disconnect_tx }, 12347); + let mut stream = create_ws_connection(12347).await; + stream + .send(Message::Text("aizdunazidaubdiz".into())) + .await + .unwrap(); + + let data = tokio::time::timeout(Duration::from_millis(1), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::TransportError::PacketParsing") + .unwrap(); + + assert_eq!(data, DisconnectReason::PacketParsingError); +} diff --git a/engineioxide/tests/fixture.rs b/engineioxide/tests/fixture.rs new file mode 100644 index 00000000..20d36ecf --- /dev/null +++ b/engineioxide/tests/fixture.rs @@ -0,0 +1,80 @@ +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::Duration, +}; + +use bytes::Buf; +use engineioxide::{config::EngineIoConfig, handler::EngineIoHandler, service::EngineIoService}; +use http::Request; +use hyper::Server; +use serde::{Deserialize, Serialize}; +use tokio::net::TcpStream; +use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; + +/// An OpenPacket is used to initiate a connection +#[derive(Debug, Serialize, Deserialize, PartialEq, PartialOrd)] +#[serde(rename_all = "camelCase")] +struct OpenPacket { + sid: String, + upgrades: Vec, + ping_interval: u64, + ping_timeout: u64, + max_payload: u64, +} + +/// Params should be in the form of `key1=value1&key2=value2` +pub async fn send_req( + port: u16, + params: String, + method: http::Method, + body: Option, +) -> String { + let body = body + .map(|b| hyper::Body::from(b)) + .unwrap_or_else(hyper::Body::empty); + let req = Request::builder() + .method(method) + .uri(format!( + "http://127.0.0.1:{port}/engine.io/?EIO=4&{}", + params + )) + .body(body) + .unwrap(); + let mut res = hyper::Client::new().request(req).await.unwrap(); + let body = hyper::body::aggregate(res.body_mut()).await.unwrap(); + String::from_utf8(body.chunk().to_vec()) + .unwrap() + .chars() + .skip(1) + .collect() +} + +pub async fn create_polling_connection(port: u16) -> String { + let body = send_req(port, format!("transport=polling"), http::Method::GET, None).await; + let open_packet: OpenPacket = serde_json::from_str(&body).unwrap(); + open_packet.sid +} +pub async fn create_ws_connection(port: u16) -> WebSocketStream> { + tokio_tungstenite::connect_async(format!( + "ws://127.0.0.1:{port}/engine.io/?EIO=4&transport=websocket" + )) + .await + .unwrap() + .0 +} + +pub fn create_server(handler: H, port: u16) { + let config = EngineIoConfig::builder() + .ping_interval(Duration::from_millis(300)) + .ping_timeout(Duration::from_millis(200)) + .max_payload(1e6 as u64) + .build(); + + let addr = &SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port); + + let svc = EngineIoService::with_config(handler, config); + + let server = Server::bind(addr).serve(svc.into_make_service()); + + tokio::spawn(server); +} diff --git a/examples/src/chat/handlers.rs b/examples/src/chat/handlers.rs index 4d9b0b40..da253393 100644 --- a/examples/src/chat/handlers.rs +++ b/examples/src/chat/handlers.rs @@ -89,4 +89,11 @@ pub async fn handler(socket: Arc>) { ); socket.to("default").emit("message", msg).ok(); }); + + socket.on_disconnect(|socket, reason| async move { + info!("Socket disconnected: {} {}", socket.sid, reason); + let Nickname(ref nickname) = *socket.extensions.get().unwrap(); + let msg = format!("{} left the chat", nickname); + socket.to("default").emit("message", msg).ok(); + }); } diff --git a/examples/src/engineio-echo/axum_echo.rs b/examples/src/engineio-echo/axum_echo.rs index 42def30c..e4c90735 100644 --- a/examples/src/engineio-echo/axum_echo.rs +++ b/examples/src/engineio-echo/axum_echo.rs @@ -1,6 +1,10 @@ use axum::routing::get; use axum::Server; -use engineioxide::{handler::EngineIoHandler, layer::EngineIoLayer, socket::Socket}; +use engineioxide::{ + handler::EngineIoHandler, + layer::EngineIoLayer, + socket::{DisconnectReason, Socket}, +}; use tracing::info; use tracing_subscriber::FmtSubscriber; @@ -14,8 +18,8 @@ impl EngineIoHandler for MyHandler { fn on_connect(&self, socket: &Socket) { println!("socket connect {}", socket.sid); } - fn on_disconnect(&self, socket: &Socket) { - println!("socket disconnect {}", socket.sid); + fn on_disconnect(&self, socket: &Socket, reason: DisconnectReason) { + println!("socket disconnect {}: {:?}", socket.sid, reason); } fn on_message(&self, msg: String, socket: &Socket) { diff --git a/examples/src/engineio-echo/hyper_echo.rs b/examples/src/engineio-echo/hyper_echo.rs index 63485ff0..bd79f271 100644 --- a/examples/src/engineio-echo/hyper_echo.rs +++ b/examples/src/engineio-echo/hyper_echo.rs @@ -1,4 +1,8 @@ -use engineioxide::{handler::EngineIoHandler, service::EngineIoService, socket::Socket}; +use engineioxide::{ + handler::EngineIoHandler, + service::EngineIoService, + socket::{DisconnectReason, Socket}, +}; use hyper::Server; use tracing::info; use tracing_subscriber::FmtSubscriber; @@ -13,8 +17,8 @@ impl EngineIoHandler for MyHandler { fn on_connect(&self, socket: &Socket) { println!("socket connect {}", socket.sid); } - fn on_disconnect(&self, socket: &Socket) { - println!("socket disconnect {}", socket.sid); + fn on_disconnect(&self, socket: &Socket, reason: DisconnectReason) { + println!("socket disconnect {}: {:?}", socket.sid, reason); } fn on_message(&self, msg: String, socket: &Socket) { diff --git a/examples/src/engineio-echo/warp_echo.rs b/examples/src/engineio-echo/warp_echo.rs index f9da5a5b..a0e4f689 100644 --- a/examples/src/engineio-echo/warp_echo.rs +++ b/examples/src/engineio-echo/warp_echo.rs @@ -1,4 +1,8 @@ -use engineioxide::{handler::EngineIoHandler, service::EngineIoService, socket::Socket}; +use engineioxide::{ + handler::EngineIoHandler, + service::EngineIoService, + socket::{DisconnectReason, Socket}, +}; use hyper::Server; use tracing::info; use tracing_subscriber::FmtSubscriber; @@ -14,8 +18,8 @@ impl EngineIoHandler for MyHandler { fn on_connect(&self, socket: &Socket) { println!("socket connect {}", socket.sid); } - fn on_disconnect(&self, socket: &Socket) { - println!("socket disconnect {}", socket.sid); + fn on_disconnect(&self, socket: &Socket, reason: DisconnectReason) { + println!("socket disconnect {}: {:?}", socket.sid, reason); } fn on_message(&self, msg: String, socket: &Socket) { diff --git a/examples/src/socketio-echo/axum_echo.rs b/examples/src/socketio-echo/axum_echo.rs index f9c33c68..cf9c5991 100644 --- a/examples/src/socketio-echo/axum_echo.rs +++ b/examples/src/socketio-echo/axum_echo.rs @@ -24,6 +24,10 @@ async fn main() -> Result<(), Box> { info!("Received event: {:?} {:?}", data, bin); ack.bin(bin).send(data).ok(); }); + + socket.on_disconnect(|socket, reason| async move { + info!("Socket.IO disconnected: {} {}", socket.sid, reason); + }); }) .add("/custom", |socket| async move { info!("Socket.IO connected on: {:?} {:?}", socket.ns(), socket.sid); diff --git a/examples/src/socketio-echo/hyper_echo.rs b/examples/src/socketio-echo/hyper_echo.rs index 7dcddf50..aa783da1 100644 --- a/examples/src/socketio-echo/hyper_echo.rs +++ b/examples/src/socketio-echo/hyper_echo.rs @@ -23,6 +23,10 @@ async fn main() -> Result<(), Box> { info!("Received event: {:?} {:?}", data, bin); ack.bin(bin).send(data).ok(); }); + + socket.on_disconnect(|socket, reason| async move { + info!("Socket.IO disconnected: {} {}", socket.sid, reason); + }); }) .add("/custom", |socket| async move { info!("Socket.IO connected on: {:?} {:?}", socket.ns(), socket.sid); diff --git a/examples/src/socketio-echo/warp_echo.rs b/examples/src/socketio-echo/warp_echo.rs index 77bba517..e81fe3a5 100644 --- a/examples/src/socketio-echo/warp_echo.rs +++ b/examples/src/socketio-echo/warp_echo.rs @@ -24,6 +24,10 @@ async fn main() -> Result<(), Box> { info!("Received event: {:?} {:?}", data, bin); ack.bin(bin).send(data).ok(); }); + + socket.on_disconnect(|socket, reason| async move { + info!("Socket.IO disconnected: {} {}", socket.sid, reason); + }); }) .add("/custom", |socket| async move { info!("Socket.IO connected on: {:?} {:?}", socket.ns(), socket.sid); diff --git a/socketioxide/Cargo.toml b/socketioxide/Cargo.toml index d4e4e6b0..1cdc8961 100644 --- a/socketioxide/Cargo.toml +++ b/socketioxide/Cargo.toml @@ -34,4 +34,14 @@ dashmap = "5.4.0" [dev-dependencies] axum = "0.6.18" -tracing-subscriber = "0.3.17" +tokio = { version = "1.26.0", features = ["macros", "parking_lot"] } +tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } +tokio-tungstenite = "0.20.0" +hyper = { version = "0.14.25", features = [ + "http1", + "http2", + "server", + "stream", + "runtime", + "client", +] } diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index 786b9b2b..14cadab7 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; use engineioxide::handler::EngineIoHandler; -use engineioxide::socket::Socket as EIoSocket; +use engineioxide::socket::{DisconnectReason as EIoDisconnectReason, Socket as EIoSocket}; use serde_json::Value; use engineioxide::sid_generator::Sid; @@ -119,13 +119,16 @@ impl EngineIoHandler for Client { fn on_connect(&self, socket: &EIoSocket) { debug!("eio socket connect {}", socket.sid); } - fn on_disconnect(&self, socket: &EIoSocket) { + fn on_disconnect(&self, socket: &EIoSocket, reason: EIoDisconnectReason) { debug!("eio socket disconnect {}", socket.sid); - self.ns.values().for_each(|ns| { - if let Err(e) = ns.remove_socket(socket.sid) { - error!("Adapter error when disconnecting {}: {}, in a multiple server scenario it could leads to desyncronisation issues", socket.sid, e); - } - }); + let data = self + .ns + .values() + .filter_map(|ns| ns.get_socket(socket.sid).ok()) + .map(|s| s.close(reason.clone().into())); + if let Err(e) = data.collect::, _>>() { + error!("Adapter error when disconnecting {}: {}, in a multiple server scenario it could leads to desyncronisation issues", socket.sid, e); + } } fn on_message(&self, msg: String, socket: &EIoSocket) { @@ -134,7 +137,7 @@ impl EngineIoHandler for Client { Ok(packet) => packet, Err(e) => { debug!("socket serialization error: {}", e); - socket.close(); + socket.close(EIoDisconnectReason::PacketParsingError); return; } }; @@ -150,9 +153,14 @@ impl EngineIoHandler for Client { } _ => self.sock_propagate_packet(packet, socket.sid), }; - if let Err(err) = res { - error!("error while processing packet: {:?}", err); - socket.close(); + if let Err(ref err) = res { + error!( + "error while processing packet to socket {}: {}", + socket.sid, err + ); + if let Some(reason) = err.into() { + socket.close(reason); + } } } @@ -162,12 +170,14 @@ impl EngineIoHandler for Client { fn on_binary(&self, data: Vec, socket: &EIoSocket) { if self.apply_payload_on_packet(data, socket) { if let Some(packet) = socket.data.partial_bin_packet.lock().unwrap().take() { - if let Err(e) = self.sock_propagate_packet(packet, socket.sid) { + if let Err(ref err) = self.sock_propagate_packet(packet, socket.sid) { debug!( "error while propagating packet to socket {}: {}", - socket.sid, e + socket.sid, err ); - socket.close(); + if let Some(reason) = err.into() { + socket.close(reason); + } } } } diff --git a/socketioxide/src/errors.rs b/socketioxide/src/errors.rs index 1aa6bcd4..781c4f5a 100644 --- a/socketioxide/src/errors.rs +++ b/socketioxide/src/errors.rs @@ -1,5 +1,5 @@ use crate::{adapter::Adapter, packet::Packet, socket::RetryablePacket, Socket}; -use engineioxide::sid_generator::Sid; +use engineioxide::{sid_generator::Sid, socket::DisconnectReason as EIoDisconnectReason}; use std::{ fmt::{Debug, Display}, sync::Arc, @@ -18,9 +18,6 @@ pub enum Error { #[error("invalid event name")] InvalidEventName, - #[error("cannot find socketio engine")] - EngineGone, - #[error("cannot find socketio socket")] SocketGone(Sid), @@ -32,6 +29,23 @@ pub enum Error { Adapter(#[from] AdapterError), } +/// Convert an [`Error`] to an [`EIoDisconnectReason`] if possible +/// +/// If the error cannot be converted to a [`DisconnectReason`] it means that the error was not fatal and the engine `Socket` can be kept alive +impl From<&Error> for Option { + fn from(value: &Error) -> Self { + use EIoDisconnectReason::*; + match value { + Error::SocketGone(_) => Some(TransportClose), + Error::EngineIoError(ref e) => e.into(), + Error::SerializeError(_) | Error::InvalidPacketType | Error::InvalidEventName => { + Some(PacketParsingError) + } + Error::Adapter(_) => None, + } + } +} + /// Error type for ack responses #[derive(thiserror::Error, Debug)] pub enum AckError { diff --git a/socketioxide/src/lib.rs b/socketioxide/src/lib.rs index 412b7b28..db139179 100644 --- a/socketioxide/src/lib.rs +++ b/socketioxide/src/lib.rs @@ -42,6 +42,11 @@ //! println!("Received acb event: {:?} {:?}", data, bin); //! ack.bin(bin).send(data).ok(); //! }); +//! // Add a callback triggered when the socket disconnect +//! // The reason of the disconnection will be passed to the callback +//! socket.on_disconnect(|socket, reason| async move { +//! println!("Socket.IO disconnected: {} {}", socket.sid, reason); +//! }); //! }) //! .add("/custom", |socket| async move { //! println!("Socket connected on /custom namespace with id: {}", socket.sid); @@ -68,9 +73,9 @@ pub use errors::{ AckError, AckSenderError, BroadcastError, Error as SocketError, SendError, TransportError, }; pub use layer::SocketIoLayer; -pub use ns::Namespace; +pub use ns::{Namespace, NsHandlers}; pub use service::SocketIoService; -pub use socket::Socket; +pub use socket::{DisconnectReason, Socket}; mod client; mod config; diff --git a/socketioxide/src/ns.rs b/socketioxide/src/ns.rs index 84d501a7..c060541e 100644 --- a/socketioxide/src/ns.rs +++ b/socketioxide/src/ns.rs @@ -3,12 +3,12 @@ use std::{ sync::{Arc, RwLock}, }; -use crate::errors::{AdapterError, SendError}; +use crate::errors::AdapterError; use crate::{ adapter::{Adapter, LocalAdapter}, errors::Error, handshake::Handshake, - packet::{Packet, PacketData}, + packet::PacketData, socket::Socket, SocketIoConfig, }; @@ -67,15 +67,7 @@ impl Namespace { socket } - pub fn disconnect(&self, sid: Sid) -> Result<(), SendError> { - if let Some(socket) = self.sockets.write().unwrap().remove(&sid) { - self.adapter - .del_all(sid) - .map_err(|err| AdapterError(Box::new(err)))?; - socket.send(Packet::disconnect(self.path.clone()))?; - } - Ok(()) - } + /// Remove a socket from a namespace and propagate the event to the adapter pub fn remove_socket(&self, sid: Sid) -> Result<(), AdapterError> { self.sockets.write().unwrap().remove(&sid); self.adapter @@ -87,19 +79,11 @@ impl Namespace { self.sockets.read().unwrap().values().any(|s| s.sid == sid) } - /// Called when a namespace receive a particular packet that should be transmitted to the socket - pub fn socket_recv(&self, sid: Sid, packet: PacketData) -> Result<(), Error> { - self.get_socket(sid)?.recv(packet) - } - pub fn recv(&self, sid: Sid, packet: PacketData) -> Result<(), Error> { match packet { - PacketData::Disconnect => self - .remove_socket(sid) - .map_err(|err| AdapterError(Box::new(err)).into()), PacketData::Connect(_) => unreachable!("connect packets should be handled before"), PacketData::ConnectError(_) => Ok(()), - packet => self.socket_recv(sid, packet), + packet => self.get_socket(sid)?.recv(packet), } } pub fn get_socket(&self, sid: Sid) -> Result>, Error> { diff --git a/socketioxide/src/socket.rs b/socketioxide/src/socket.rs index 8fc0e204..9f43d73a 100644 --- a/socketioxide/src/socket.rs +++ b/socketioxide/src/socket.rs @@ -10,15 +10,17 @@ use std::{ time::Duration, }; -use engineioxide::{sid_generator::Sid, SendPacket as EnginePacket}; -use futures::Future; +use engineioxide::{ + sid_generator::Sid, socket::DisconnectReason as EIoDisconnectReason, SendPacket as EnginePacket, +}; +use futures::{future::BoxFuture, Future}; use serde::{de::DeserializeOwned, Serialize}; use serde_json::Value; use tokio::sync::mpsc::error::TrySendError; use tokio::sync::oneshot; use tracing::debug; -use crate::errors::{SendError, TransportError}; +use crate::errors::{AdapterError, SendError, TransportError}; use crate::{ adapter::{Adapter, Room}, errors::{AckError, Error}, @@ -31,13 +33,72 @@ use crate::{ SocketIoConfig, }; +pub type DisconnectCallback = Box< + dyn FnOnce(Arc>, DisconnectReason) -> BoxFuture<'static, ()> + Send + Sync + 'static, +>; + +/// All the possible reasons for a [`Socket`] to be disconnected. +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum DisconnectReason { + /// The client gracefully closed the connection + TransportClose, + + /// The client sent multiple polling requests at the same time (it is forbidden according to the engine.io protocol) + MultipleHttpPollingError, + + /// The client sent a bad request / the packet could not be parsed correctly + PacketParsingError, + + /// The connection was closed (example: the user has lost connection, or the network was changed from WiFi to 4G) + TransportError, + + /// The client did not send a PONG packet in the [ping timeout](crate::SocketIoConfigBuilder) delay + HeartbeatTimeout, + + /// The client has manually disconnected the socket using [`socket.disconnect()`](https://socket.io/fr/docs/v4/client-api/#socketdisconnect) + ClientNSDisconnect, + + /// The socket was forcefully disconnected from the namespace with [`Socket::disconnect`] + ServerNSDisconnect, +} + +impl std::fmt::Display for DisconnectReason { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use DisconnectReason::*; + let str: &'static str = match self { + TransportClose => "client gracefully closed the connection", + MultipleHttpPollingError => "client sent multiple polling requests at the same time", + PacketParsingError => "client sent a bad request / the packet could not be parsed", + TransportError => "The connection was abruptly closed", + HeartbeatTimeout => "client did not send a PONG packet in time", + ClientNSDisconnect => "client has manually disconnected the socket from the namespace", + ServerNSDisconnect => "socket was forcefully disconnected from the namespace", + }; + f.write_str(str) + } +} + +impl From for DisconnectReason { + fn from(reason: EIoDisconnectReason) -> Self { + use DisconnectReason::*; + match reason { + EIoDisconnectReason::TransportClose => TransportClose, + EIoDisconnectReason::TransportError => TransportError, + EIoDisconnectReason::HeartbeatTimeout => HeartbeatTimeout, + EIoDisconnectReason::MultipleHttpPollingError => MultipleHttpPollingError, + EIoDisconnectReason::PacketParsingError => PacketParsingError, + } + } +} + /// A Socket represents a client connected to a namespace. /// It is used to send and receive messages from the client, join and leave rooms, etc. pub struct Socket { config: Arc, ns: Arc>, message_handlers: RwLock>>, - ack_message: RwLock>>>, + disconnect_handler: Mutex>>, + ack_message: Mutex>>>, ack_counter: AtomicI64, pub handshake: Handshake, pub sid: Sid, @@ -194,7 +255,8 @@ impl Socket { Self { ns, message_handlers: RwLock::new(HashMap::new()), - ack_message: RwLock::new(HashMap::new()), + disconnect_handler: Mutex::new(None), + ack_message: Mutex::new(HashMap::new()), ack_counter: AtomicI64::new(0), handshake, sid, @@ -269,6 +331,31 @@ impl Socket { .insert(event.into(), MessageHandler::boxed(handler)); } + /// ## Register a disconnect handler. + /// The callback will be called when the socket is disconnected from the server or the client or when the underlying connection crashes. + /// A [`DisconnectReason`](crate::DisconnectReason) is passed to the callback to indicate the reason for the disconnection. + /// ### Example + /// ``` + /// # use socketioxide::Namespace; + /// # use serde_json::Value; + /// Namespace::builder().add("/", |socket| async move { + /// socket.on("test", |socket, data: Value, bin, _| async move { + /// // Close the current socket + /// socket.disconnect().ok(); + /// }); + /// socket.on_disconnect(|socket, reason| async move { + /// println!("Socket {} on ns {} disconnected, reason: {:?}", socket.sid, socket.ns(), reason); + /// }); + /// }); + pub fn on_disconnect(&self, callback: C) + where + C: Fn(Arc>, DisconnectReason) -> F + Send + Sync + 'static, + F: Future + Send + 'static, + { + let handler = Box::new(move |s, r| Box::pin(callback(s, r)) as _); + *self.disconnect_handler.lock().unwrap() = Some(handler); + } + /// Emit a message to the client /// ##### Example /// ``` @@ -503,9 +590,13 @@ impl Socket { Operators::new(self.ns.clone(), self.sid).broadcast() } - /// Disconnect the socket from the current namespace. - pub fn disconnect(&self) -> Result<(), SendError> { - self.ns.disconnect(self.sid) + /// Disconnect the socket from the current namespace, + /// + /// It will also call the disconnect handler if it is set. + pub fn disconnect(self: Arc) -> Result<(), SendError> { + self.send(Packet::disconnect(self.ns.path.clone()))?; + self.close(DisconnectReason::ServerNSDisconnect)?; + Ok(()) } /// Get the current namespace path. @@ -524,7 +615,7 @@ impl Socket { ) -> Result, AckError> { let (tx, rx) = oneshot::channel(); let ack = self.ack_counter.fetch_add(1, Ordering::SeqCst) + 1; - self.ack_message.write().unwrap().insert(ack, tx); + self.ack_message.lock().unwrap().insert(ack, tx); packet.inner.set_ack_id(ack); self.send(packet)?; let timeout = timeout.unwrap_or(self.config.ack_timeout); @@ -532,14 +623,26 @@ impl Socket { Ok((serde_json::from_value(v.0)?, v.1)) } - // Receive data from client: + /// Called when the socket is gracefully disconnected from the server or the client + /// + /// It maybe also closed when the underlying transport is closed or failed. + pub(crate) fn close(self: Arc, reason: DisconnectReason) -> Result<(), AdapterError> { + if let Some(handler) = self.disconnect_handler.lock().unwrap().take() { + tokio::spawn(handler(self.clone(), reason)); + } + self.ns.remove_socket(self.sid) + } + // Receive data from client: pub(crate) fn recv(self: Arc, packet: PacketData) -> Result<(), Error> { match packet { PacketData::Event(e, data, ack) => self.recv_event(e, data, ack), PacketData::EventAck(data, ack_id) => self.recv_ack(data, ack_id), PacketData::BinaryEvent(e, packet, ack) => self.recv_bin_event(e, packet, ack), PacketData::BinaryAck(packet, ack) => self.recv_bin_ack(packet, ack), + PacketData::Disconnect => self + .close(DisconnectReason::ClientNSDisconnect) + .map_err(Error::from), _ => unreachable!(), } } @@ -564,14 +667,14 @@ impl Socket { } fn recv_ack(self: Arc, data: Value, ack: i64) -> Result<(), Error> { - if let Some(tx) = self.ack_message.write().unwrap().remove(&ack) { + if let Some(tx) = self.ack_message.lock().unwrap().remove(&ack) { tx.send((data, vec![])).ok(); } Ok(()) } fn recv_bin_ack(self: Arc, packet: BinaryPacket, ack: i64) -> Result<(), Error> { - if let Some(tx) = self.ack_message.write().unwrap().remove(&ack) { + if let Some(tx) = self.ack_message.lock().unwrap().remove(&ack) { tx.send((packet.data, packet.bin)).ok(); } Ok(()) diff --git a/socketioxide/tests/disconnect_reason.rs b/socketioxide/tests/disconnect_reason.rs new file mode 100644 index 00000000..035ac82c --- /dev/null +++ b/socketioxide/tests/disconnect_reason.rs @@ -0,0 +1,236 @@ +//! Tests for disconnect reasons +//! Test are made on polling and websocket transports for engine.io errors and only websocket for socket.io errors: +//! * Heartbeat timeout +//! * Transport close +//! * Multiple http polling +//! * Packet parsing +//! +//! * Client namespace disconnect +//! * Server namespace disconnect + +use std::time::Duration; + +use futures::SinkExt; +use socketioxide::{adapter::LocalAdapter, DisconnectReason, Namespace, NsHandlers}; +use tokio::sync::mpsc; + +mod fixture; + +use fixture::{create_server, send_req}; +use tokio_tungstenite::tungstenite::Message; + +use crate::fixture::{create_polling_connection, create_ws_connection}; + +fn create_handler() -> (NsHandlers, mpsc::Receiver) { + let (tx, rx) = mpsc::channel::(1); + let ns = Namespace::builder() + .add("/", move |socket| { + println!("Socket connected on / namespace with id: {}", socket.sid); + let tx = tx.clone(); + socket.on_disconnect(move |socket, reason| { + println!("Socket.IO disconnected: {} {}", socket.sid, reason); + tx.try_send(reason).unwrap(); + async move {} + }); + + async move {} + }) + .build(); + (ns, rx) +} + +// Engine IO Disconnect Reason Tests + +#[tokio::test] +pub async fn polling_heartbeat_timeout() { + let (ns, mut rx) = create_handler(); + create_server(ns, 1234); + create_polling_connection(1234).await; + + let data = tokio::time::timeout(Duration::from_millis(500), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::HeartbeatTimeout") + .unwrap(); + + assert_eq!(data, DisconnectReason::HeartbeatTimeout); +} + +#[tokio::test] +pub async fn ws_heartbeat_timeout() { + let (ns, mut rx) = create_handler(); + create_server(ns, 12344); + let _stream = create_ws_connection(12344).await; + + let data = tokio::time::timeout(Duration::from_millis(500), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::HeartbeatTimeout") + .unwrap(); + + assert_eq!(data, DisconnectReason::HeartbeatTimeout); +} + +#[tokio::test] +pub async fn polling_transport_closed() { + let (ns, mut rx) = create_handler(); + create_server(ns, 1235); + let sid = create_polling_connection(1235).await; + + send_req( + 1235, + format!("transport=polling&sid={sid}"), + http::Method::POST, + Some("1".into()), + ) + .await; + + let data = tokio::time::timeout(Duration::from_millis(1), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::TransportClose") + .unwrap(); + + assert_eq!(data, DisconnectReason::TransportClose); +} + +#[tokio::test] +pub async fn ws_transport_closed() { + let (ns, mut rx) = create_handler(); + create_server(ns, 12345); + let mut stream = create_ws_connection(12345).await; + + stream.send(Message::Text("1".into())).await.unwrap(); + + let data = tokio::time::timeout(Duration::from_millis(1), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::TransportClose") + .unwrap(); + + assert_eq!(data, DisconnectReason::TransportClose); +} + +#[tokio::test] +pub async fn multiple_http_polling() { + let (ns, mut rx) = create_handler(); + create_server(ns, 1236); + let sid = create_polling_connection(1236).await; + + // First request to flush the server buffer containing the open packet + send_req( + 1236, + format!("transport=polling&sid={sid}"), + http::Method::GET, + None, + ) + .await; + + tokio::spawn(futures::future::join_all(vec![ + send_req( + 1236, + format!("transport=polling&sid={sid}"), + http::Method::GET, + None, + ), + send_req( + 1236, + format!("transport=polling&sid={sid}"), + http::Method::GET, + None, + ), + ])); + + let data = tokio::time::timeout(Duration::from_millis(10), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::DisconnectError::MultipleHttpPolling") + .unwrap(); + + assert_eq!(data, DisconnectReason::MultipleHttpPollingError); +} + +#[tokio::test] +pub async fn polling_packet_parsing() { + let (ns, mut rx) = create_handler(); + create_server(ns, 1237); + let sid = create_polling_connection(1237).await; + send_req( + 1237, + format!("transport=polling&sid={sid}"), + http::Method::POST, + Some("aizdunazidaubdiz".into()), + ) + .await; + + let data = tokio::time::timeout(Duration::from_millis(1), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::PacketParsingError") + .unwrap(); + + assert_eq!(data, DisconnectReason::PacketParsingError); +} + +#[tokio::test] +pub async fn ws_packet_parsing() { + let (ns, mut rx) = create_handler(); + create_server(ns, 12347); + let mut stream = create_ws_connection(12347).await; + stream + .send(Message::Text("aizdunazidaubdiz".into())) + .await + .unwrap(); + + let data = tokio::time::timeout(Duration::from_millis(1), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::PacketParsingError") + .unwrap(); + + assert_eq!(data, DisconnectReason::PacketParsingError); +} + +// Socket IO Disconnect Reason Tests + +#[tokio::test] +pub async fn client_ns_disconnect() { + let (ns, mut rx) = create_handler(); + create_server(ns, 12348); + let mut stream = create_ws_connection(12348).await; + + stream.send(Message::Text("41".into())).await.unwrap(); + + let data = tokio::time::timeout(Duration::from_millis(1), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::ClientNSDisconnect") + .unwrap(); + + assert_eq!(data, DisconnectReason::ClientNSDisconnect); +} + +#[tokio::test] +pub async fn server_ns_disconnect() { + let (tx, mut rx) = mpsc::channel::(1); + let ns = Namespace::builder() + .add("/", move |socket| { + println!("Socket connected on / namespace with id: {}", socket.sid); + let sock = socket.clone(); + let tx = tx.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + sock.disconnect().unwrap(); + }); + + socket.on_disconnect(move |socket, reason| { + println!("Socket.IO disconnected: {} {}", socket.sid, reason); + tx.try_send(reason).unwrap(); + async move {} + }); + + async move {} + }) + .build(); + + create_server(ns, 12349); + let _stream = create_ws_connection(12349).await; + + let data = tokio::time::timeout(Duration::from_millis(20), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::ServerNSDisconnect") + .unwrap(); + assert_eq!(data, DisconnectReason::ServerNSDisconnect); +} diff --git a/socketioxide/tests/fixture.rs b/socketioxide/tests/fixture.rs new file mode 100644 index 00000000..e6a675ec --- /dev/null +++ b/socketioxide/tests/fixture.rs @@ -0,0 +1,92 @@ +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::Duration, +}; + +use futures::SinkExt; +use http::Request; +use hyper::{body::Buf, Server}; +use serde::{Deserialize, Serialize}; +use socketioxide::{adapter::LocalAdapter, NsHandlers, SocketIoConfig, SocketIoService}; +use tokio::net::TcpStream; +use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream}; +/// An OpenPacket is used to initiate a connection +#[derive(Debug, Serialize, Deserialize, PartialEq, PartialOrd)] +#[serde(rename_all = "camelCase")] +struct OpenPacket { + sid: String, + upgrades: Vec, + ping_interval: u64, + ping_timeout: u64, + max_payload: u64, +} + +/// Params should be in the form of `key1=value1&key2=value2` +pub async fn send_req( + port: u16, + params: String, + method: http::Method, + body: Option, +) -> String { + let body = body + .map(|b| hyper::Body::from(b)) + .unwrap_or_else(hyper::Body::empty); + let req = Request::builder() + .method(method) + .uri(format!( + "http://127.0.0.1:{port}/socket.io/?EIO=4&{}", + params + )) + .body(body) + .unwrap(); + let mut res = hyper::Client::new().request(req).await.unwrap(); + let body = hyper::body::aggregate(res.body_mut()).await.unwrap(); + String::from_utf8(body.chunk().to_vec()) + .unwrap() + .chars() + .skip(1) + .collect() +} + +pub async fn create_polling_connection(port: u16) -> String { + let body = send_req(port, format!("transport=polling"), http::Method::GET, None).await; + let open_packet: OpenPacket = serde_json::from_str(&body).unwrap(); + + send_req( + port, + format!("transport=polling&sid={}", open_packet.sid), + http::Method::POST, + Some("40{}".to_string()), + ) + .await; + + open_packet.sid +} +pub async fn create_ws_connection(port: u16) -> WebSocketStream> { + let mut ws = tokio_tungstenite::connect_async(format!( + "ws://127.0.0.1:{port}/socket.io/?EIO=4&transport=websocket" + )) + .await + .unwrap() + .0; + + ws.send(Message::Text("40{}".to_string())).await.unwrap(); + + ws +} + +pub fn create_server(ns: NsHandlers, port: u16) { + let config = SocketIoConfig::builder() + .ping_interval(Duration::from_millis(300)) + .ping_timeout(Duration::from_millis(200)) + .max_payload(1e6 as u64) + .build(); + + let addr = &SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port); + + let svc = SocketIoService::with_config(ns, config); + + let server = Server::bind(addr).serve(svc.into_make_service()); + + tokio::spawn(server); +}