diff --git a/.github/workflows/engineio-ci.yml b/.github/workflows/engineio-ci.yml new file mode 100644 index 00000000..499098fc --- /dev/null +++ b/.github/workflows/engineio-ci.yml @@ -0,0 +1,39 @@ + +name: EngineIO CI + +on: + push: + tags: + - v* + pull_request: + branches: + - main + - develop + +jobs: + e2e_v3: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - uses: actions-rs/toolchain@v1 + with: + toolchain: stable + - uses: actions/checkout@v3 + with: + repository: socketio/engine.io-protocol + path: engine.io-protocol + ref: v3 + - uses: actions/setup-node@v3 + with: + node-version: 16 + - name: Install deps & run tests + run: | + cd engine.io-protocol/test-suite && npm install && cd ../.. + cargo run --bin engineioxide-e2e --release --features v3 --no-default-features > v3_server.txt & npm --prefix engine.io-protocol/test-suite test > v3_client.txt + - name: Server output + if: always() + run: cat v3_server.txt + - name: Client output + if: always() + run: cat v3_client.txt \ No newline at end of file diff --git a/.github/workflows/socketio-ci.yml b/.github/workflows/socketio-ci.yml index 2900da8a..7ef0dccf 100644 --- a/.github/workflows/socketio-ci.yml +++ b/.github/workflows/socketio-ci.yml @@ -27,7 +27,7 @@ jobs: - uses: actions-rs/toolchain@v1 with: toolchain: stable - - run: cargo test + - run: cargo test --all-features e2e: runs-on: ubuntu-latest diff --git a/engineioxide/Cargo.toml b/engineioxide/Cargo.toml index 1256242e..123b2c80 100644 --- a/engineioxide/Cargo.toml +++ b/engineioxide/Cargo.toml @@ -36,4 +36,9 @@ criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] } [[bench]] name = "benchmark_polling" -harness = false \ No newline at end of file +harness = false + +[features] +default = ["v4"] +v4 = [] +v3 = [] diff --git a/engineioxide/Readme.md b/engineioxide/Readme.md index a5b9ce7b..e5f360ef 100644 --- a/engineioxide/Readme.md +++ b/engineioxide/Readme.md @@ -47,3 +47,23 @@ async fn main() -> Result<(), Box> { Ok(()) } ``` + +### Supported Protocols +You can enable support for other EngineIO protocol implementations through feature flags. +The latest supported protocol version is enabled by default. + +To add support for another protocol version, adjust your dependency configuration accordingly: + +```toml +[dependencies] +# Enables the `v3` protocol (`v4` is also implicitly enabled, as it's the default). +engineioxide = { version = "0.3.0", features = ["v3"] } +``` + +To enable *a single protocol version only*, disable default features: + +```toml +[dependencies] +# Enables the `v3` protocol only. +engineioxide = { version = "0.3.0", features = ["v3"], default-features = false } +``` diff --git a/engineioxide/src/config.rs b/engineioxide/src/config.rs index cae3cf72..601fd2a5 100644 --- a/engineioxide/src/config.rs +++ b/engineioxide/src/config.rs @@ -53,6 +53,7 @@ impl EngineIoConfigBuilder { config: EngineIoConfig::default(), } } + /// The path to listen for engine.io requests on. /// Defaults to "/engine.io". pub fn req_path(mut self, req_path: String) -> Self { diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 0aa379af..74c42cca 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -1,11 +1,9 @@ #![deny(clippy::await_holding_lock)] use std::{ collections::HashMap, - io::BufRead, sync::{Arc, RwLock}, }; -use crate::sid_generator::Sid; use crate::{ body::ResponseBody, config::EngineIoConfig, @@ -17,6 +15,11 @@ use crate::{ sid_generator::generate_sid, socket::{ConnectionType, Socket, SocketReq}, }; +use crate::{ + payload::{Payload, PACKET_SEPARATOR}, + service::ProtocolVersion, + sid_generator::Sid, +}; use bytes::Buf; use futures::{stream::SplitStream, SinkExt, StreamExt, TryStreamExt}; use http::{Request, Response, StatusCode}; @@ -55,6 +58,7 @@ impl EngineIo { /// Send an open packet pub(crate) fn on_open_http_req( self: Arc, + protocol: ProtocolVersion, req: Request, ) -> Result>, Error> where @@ -65,6 +69,7 @@ impl EngineIo { let sid = generate_sid(); let socket = Socket::new( sid, + protocol, ConnectionType::Http, &self.config, SocketReq::from(req.into_parts().0), @@ -81,6 +86,20 @@ impl EngineIo { let packet = OpenPacket::new(TransportType::Polling, sid, &self.config); let packet: String = Packet::Open(packet).try_into()?; + let packet = { + #[cfg(feature = "v3")] + { + let mut packet = packet; + // The V3 protocol requires the packet length to be prepended to the packet. + // It doesn't use a packet separator like the V4 protocol (and up). + if protocol == ProtocolVersion::V3 { + packet = format!("{}:{}", packet.chars().count(), packet); + } + packet + } + #[cfg(not(feature = "v3"))] + packet + }; http_response(StatusCode::OK, packet).map_err(Error::Http) } @@ -90,6 +109,7 @@ impl EngineIo { /// Otherwise it will wait for the next packet to be sent from the socket pub(crate) async fn on_polling_http_req( self: Arc, + protocol: ProtocolVersion, sid: Sid, ) -> Result>, Error> where @@ -118,7 +138,16 @@ impl EngineIo { debug!("sending packet: {:?}", packet); let packet: String = packet.try_into().unwrap(); if !data.is_empty() { - data.push('\x1e'); + // The V3 protocol requires the packet length to be prepended to the packet. + // The V4 protocol (and up) only requires a packet separator. + match protocol { + ProtocolVersion::V3 => data.push_str(&format!("{}:", packet.chars().count())), + ProtocolVersion::V4 => { + data.push(std::char::from_u32(PACKET_SEPARATOR as u32).unwrap()) + } + } + } else if protocol == ProtocolVersion::V3 { + data.push_str(&format!("{}:", packet.chars().count())); } data.push_str(&packet); } @@ -127,6 +156,13 @@ impl EngineIo { if data.is_empty() { let packet = rx.recv().await.ok_or(Error::Aborted)?; let packet: String = packet.try_into().unwrap(); + #[cfg(feature = "v3")] + { + // The V3 protocol specifically requires the packet length to be prepended to the packet. + if protocol == ProtocolVersion::V3 { + data.push_str(&format!("{}:", packet.chars().count())); + } + } data.push_str(&packet); } Ok(http_response(StatusCode::OK, data)?) @@ -137,6 +173,7 @@ impl EngineIo { /// Split the body into packets and send them to the internal socket pub(crate) async fn on_post_http_req( self: Arc, + protocol: ProtocolVersion, sid: Sid, body: Request, ) -> Result>, Error> @@ -150,35 +187,37 @@ impl EngineIo { debug!("error aggregating body: {:?}", e); Error::HttpErrorResponse(StatusCode::BAD_REQUEST) })?; - let packets = body.reader().split(b'\x1e'); let socket = self .get_socket(sid) .ok_or(Error::UnknownSessionID(sid)) .and_then(|s| s.is_http().then(|| s).ok_or(Error::TransportMismatch))?; - for packet in packets { - match Packet::try_from(packet?) { - Err(e) => { - debug!("[sid={sid}] error parsing packet: {:?}", e); - self.close_session(sid); - Err(e) - } + let packets = Payload::new(protocol, body.reader()); + + for p in packets { + let raw_packet = p.map_err(|e| { + debug!("error parsing packets: {:?}", e); + self.close_session(sid); + Error::HttpErrorResponse(StatusCode::BAD_REQUEST) + })?; + + match Packet::try_from(raw_packet) { Ok(Packet::Close) => { debug!("[sid={sid}] closing session"); socket.send(Packet::Noop)?; self.close_session(sid); break; } - Ok(Packet::Pong) => socket - .pong_tx + Ok(Packet::Pong) | Ok(Packet::Ping) => socket + .heartbeat_tx .try_send(()) .map_err(|_| Error::HeartbeatTimeout), Ok(Packet::Message(msg)) => { self.handler.on_message(msg, &socket); Ok(()) } - Ok(Packet::Binary(bin)) => { + Ok(Packet::Binary(bin)) | Ok(Packet::BinaryV3(bin)) => { self.handler.on_binary(bin, &socket); Ok(()) } @@ -186,6 +225,11 @@ impl EngineIo { debug!("[sid={sid}] bad packet received: {:?}", &p); Err(Error::BadPacket(p)) } + Err(e) => { + debug!("[sid={sid}] error parsing packet: {:?}", e); + self.close_session(sid); + return Err(e); + } }?; } Ok(http_response(StatusCode::OK, "ok")?) @@ -197,6 +241,7 @@ impl EngineIo { /// the http polling request is closed and the SID is kept for the websocket pub(crate) fn on_ws_req( self: Arc, + protocol: ProtocolVersion, sid: Option, req: Request, ) -> Result>, Error> { @@ -211,7 +256,7 @@ impl EngineIo { let req = Request::from_parts(parts, ()); tokio::spawn(async move { match hyper::upgrade::on(req).await { - Ok(conn) => match self.on_ws_req_init(conn, sid, req_data).await { + Ok(conn) => match self.on_ws_req_init(conn, protocol, sid, req_data).await { Ok(_) => debug!("ws closed"), Err(e) => debug!("ws closed with error: {:?}", e), }, @@ -230,6 +275,7 @@ impl EngineIo { async fn on_ws_req_init( self: Arc, conn: Upgraded, + protocol: ProtocolVersion, sid: Option, req_data: SocketReq, ) -> Result<(), Error> { @@ -241,7 +287,7 @@ impl EngineIo { Some(_) => { debug!("[sid={sid}] websocket connection upgrade"); let mut ws = ws_init().await; - self.ws_upgrade_handshake(sid, &mut ws).await?; + self.ws_upgrade_handshake(protocol, sid, &mut ws).await?; (self.get_socket(sid).unwrap(), ws) } } @@ -251,6 +297,7 @@ impl EngineIo { let close_fn = Box::new(move |sid: Sid| engine.close_session(sid)); let socket = Socket::new( sid, + protocol, ConnectionType::WebSocket, &self.config, req_data, @@ -276,7 +323,9 @@ impl EngineIo { let mut socket_rx = rx_socket.internal_rx.try_lock().unwrap(); while let Some(item) = socket_rx.recv().await { let res = match item { - Packet::Binary(bin) => tx.send(Message::Binary(bin)).await, + Packet::Binary(bin) | Packet::BinaryV3(bin) => { + tx.send(Message::Binary(bin)).await + } Packet::Close => tx.send(Message::Close(None)).await, _ => { let packet: String = item.try_into().unwrap(); @@ -315,8 +364,8 @@ impl EngineIo { self.close_session(socket.sid); break; } - Packet::Pong => socket - .pong_tx + Packet::Pong | Packet::Ping => socket + .heartbeat_tx .try_send(()) .map_err(|_| Error::HeartbeatTimeout), Packet::Message(msg) => { @@ -361,12 +410,19 @@ impl EngineIo { /// ``` async fn ws_upgrade_handshake( &self, + protocol: ProtocolVersion, sid: Sid, ws: &mut WebSocketStream, ) -> Result<(), Error> { let socket = self.get_socket(sid).unwrap(); - // send a NOOP packet to any pending polling request so it closes gracefully - socket.send(Packet::Noop)?; + + #[cfg(feature = "v4")] + { + // send a NOOP packet to any pending polling request so it closes gracefully' + if protocol == ProtocolVersion::V4 { + socket.send(Packet::Noop)?; + } + } // Fetch the next packet from the ws stream, it should be a PingUpgrade packet let msg = match ws.next().await { @@ -382,10 +438,27 @@ impl EngineIo { p => Err(Error::BadPacket(p))?, }; + #[cfg(feature = "v3")] + { + // send a NOOP packet to any pending polling request so it closes gracefully + // V3 protocol introduce _paused_ polling transport which require to close + // the polling request **after** the ping/pong handshake + if protocol == ProtocolVersion::V3 { + socket.send(Packet::Noop)?; + } + } + // Fetch the next packet from the ws stream, it should be an Upgrade packet let msg = match ws.next().await { Some(Ok(Message::Text(d))) => d, - _ => Err(Error::UpgradeError)?, + Some(Ok(Message::Close(_))) => { + debug!("ws stream closed before upgrade"); + Err(Error::UpgradeError)? + } + _ => { + debug!("unexpected ws message before upgrade"); + Err(Error::UpgradeError)? + } }; match Packet::try_from(msg)? { Packet::Upgrade => debug!("[sid={sid}] ws upgraded successful"), @@ -431,3 +504,36 @@ impl EngineIo { self.sockets.read().unwrap().get(&sid).cloned() } } + +#[cfg(test)] +mod tests { + use async_trait::async_trait; + + use super::*; + + #[derive(Debug, Clone)] + struct MockHandler; + + #[async_trait] + impl EngineIoHandler for MockHandler { + type Data = (); + + fn on_connect(&self, socket: &Socket) { + println!("socket connect {}", socket.sid); + } + + fn on_disconnect(&self, socket: &Socket) { + println!("socket disconnect {}", socket.sid); + } + + fn on_message(&self, msg: String, socket: &Socket) { + println!("Ping pong message {:?}", msg); + socket.emit(msg).ok(); + } + + fn on_binary(&self, data: Vec, socket: &Socket) { + println!("Ping pong binary message {:?}", data); + socket.emit_binary(data).ok(); + } + } +} diff --git a/engineioxide/src/errors.rs b/engineioxide/src/errors.rs index 5e307019..1adf9263 100644 --- a/engineioxide/src/errors.rs +++ b/engineioxide/src/errors.rs @@ -48,6 +48,9 @@ pub enum Error { TransportMismatch, #[error("unsupported protocol version")] UnsupportedProtocolVersion, + + #[error("Invalid packet length")] + InvalidPacketLength, } /// Convert an error into an http response diff --git a/engineioxide/src/lib.rs b/engineioxide/src/lib.rs index d36d737c..75e61e53 100644 --- a/engineioxide/src/lib.rs +++ b/engineioxide/src/lib.rs @@ -3,6 +3,9 @@ pub use async_trait::async_trait; /// A Packet type to use when sending data to the client pub use packet::SendPacket; +#[cfg(not(any(feature = "v3", feature = "v4")))] +compile_error!("At least one protocol version must be enabled"); + pub mod config; pub mod errors; pub mod handler; @@ -15,4 +18,5 @@ mod body; mod engine; mod futures; mod packet; +mod payload; mod utils; diff --git a/engineioxide/src/packet.rs b/engineioxide/src/packet.rs index 01376dee..706d6f05 100644 --- a/engineioxide/src/packet.rs +++ b/engineioxide/src/packet.rs @@ -47,6 +47,15 @@ pub enum Packet { /// /// When receiving, it is only used with polling connection, websocket use binary frame Binary(Vec), // Not part of the protocol, used internally + + /// Binary packet used to send binary data to the client + /// Converts to a String using base64 encoding when using polling connection + /// Or to a websocket binary frame when using websocket connection + /// + /// When receiving, it is only used with polling connection, websocket use binary frame + /// + /// This is a special packet, excepionally specific to the V3 protocol. + BinaryV3(Vec), // Not part of the protocol, used internally } /// Serialize a [Packet] to a [String] according to the Engine.IO protocol @@ -66,6 +75,7 @@ impl TryInto for Packet { Packet::Upgrade => "5".to_string(), Packet::Noop => "6".to_string(), Packet::Binary(data) => "b".to_string() + &general_purpose::STANDARD.encode(data), + Packet::BinaryV3(data) => "b4".to_string() + &general_purpose::STANDARD.encode(data), }; Ok(res) } @@ -75,9 +85,9 @@ impl TryFrom for Packet { type Error = crate::errors::Error; fn try_from(value: String) -> Result { let mut chars = value.chars(); - let packet_type = chars.next().ok_or(serde_json::Error::custom( - "Packet type not found in packet string", - ))?; + let packet_type = chars + .next() + .ok_or_else(|| serde_json::Error::custom("Packet type not found in packet string"))?; let packet_data = chars.as_str(); let is_upgrade = packet_data.starts_with("probe"); let res = match packet_type { @@ -100,6 +110,9 @@ impl TryFrom for Packet { '4' => Packet::Message(packet_data.to_string()), '5' => Packet::Upgrade, '6' => Packet::Noop, + 'b' if value.starts_with("b4") => { + Packet::BinaryV3(general_purpose::STANDARD.decode(packet_data[1..].as_bytes())?) + } 'b' => Packet::Binary(general_purpose::STANDARD.decode(packet_data.as_bytes())?), c => Err(serde_json::Error::custom( "Invalid packet type ".to_string() + &c.to_string(), @@ -231,6 +244,20 @@ mod tests { assert_eq!(packet, Packet::Binary(vec![1, 2, 3])); } + #[test] + fn test_binary_packet_v3() { + let packet = Packet::BinaryV3(vec![1, 2, 3]); + let packet_str: String = packet.try_into().unwrap(); + assert_eq!(packet_str, "b4AQID"); + } + + #[test] + fn test_binary_packet_v3_deserialize() { + let packet_str = "b4AQID".to_string(); + let packet: Packet = packet_str.try_into().unwrap(); + assert_eq!(packet, Packet::BinaryV3(vec![1, 2, 3])); + } + #[test] fn test_send_packet_into_packet() { let packet = SendPacket::Message("hello".to_string()); diff --git a/engineioxide/src/payload.rs b/engineioxide/src/payload.rs new file mode 100644 index 00000000..bff81a0b --- /dev/null +++ b/engineioxide/src/payload.rs @@ -0,0 +1,146 @@ +use std::{io::BufRead, vec}; + +use crate::{errors::Error, service::ProtocolVersion}; + +pub const PACKET_SEPARATOR: u8 = b'\x1e'; + +/// A payload is a series of encoded packets tied together. +/// How packets are tied together depends on the protocol. +pub struct Payload { + reader: R, + buffer: Vec, + #[allow(dead_code)] + protocol: ProtocolVersion, +} + +type Item = Result; // TODO: refactor to return Result<&str, Error> instead (see TODO's below) + +impl Payload { + pub fn new(protocol: ProtocolVersion, data: R) -> Self { + Payload { + reader: data, + buffer: vec![], + protocol, + } + } + + #[cfg(feature = "v3")] + fn next_v3(&mut self) -> Option { + match self.reader.read_until(b':', &mut self.buffer) { + Ok(bytes_read) => (bytes_read > 0).then(|| { + if self.buffer.ends_with(&[b':']) { + self.buffer.pop(); + } + + let buffer = std::mem::take(&mut self.buffer); + let length = std::str::from_utf8(&buffer) + .map_err(|_| Error::InvalidPacketLength) + .and_then(|s| s.parse::().map_err(|_| Error::InvalidPacketLength))?; + + self.buffer.resize(length, 0); + self.reader.read_exact(&mut self.buffer)?; + + let buffer = std::mem::take(&mut self.buffer); + String::from_utf8(buffer).map_err(Into::into) // TODO: replace 'String::from_utf8' with 'std::str::from_utf8' + }), + Err(e) => Some(Err(Error::Io(e))), + } + } + + #[cfg(feature = "v4")] + fn next_v4(&mut self) -> Option { + match self.reader.read_until(PACKET_SEPARATOR, &mut self.buffer) { + Ok(bytes_read) => { + if bytes_read > 0 { + // remove trailing separator + if self.buffer.ends_with(&[PACKET_SEPARATOR]) { + self.buffer.pop(); + } + + let buffer = std::mem::take(&mut self.buffer); + Some(String::from_utf8(buffer).map_err(Into::into)) // TODO: replace 'String::from_utf8' with 'std::str::from_utf8' + } else { + None + } + } + Err(e) => Some(Err(Error::Io(e))), + } + } +} + +impl Iterator for Payload { + type Item = Item; + + #[cfg(all(feature = "v3", feature = "v4"))] + fn next(&mut self) -> Option { + match self.protocol { + ProtocolVersion::V3 => self.next_v3(), + ProtocolVersion::V4 => self.next_v4(), + } + } + + #[cfg(feature = "v3")] + #[cfg(not(feature = "v4"))] + fn next(&mut self) -> Option { + self.next_v3() + } + + #[cfg(feature = "v4")] + #[cfg(not(feature = "v3"))] + fn next(&mut self) -> Option { + self.next_v4() + } +} + +#[cfg(test)] +mod tests { + use std::{ + io::{BufReader, Cursor}, + vec, + }; + + use crate::service::ProtocolVersion; + + use super::{Payload, PACKET_SEPARATOR}; + + #[test] + fn test_payload_iterator_v4() -> Result<(), String> { + assert!(cfg!(feature = "v4")); + + let data = BufReader::new(Cursor::new(vec![ + b'f', + b'o', + b'o', + PACKET_SEPARATOR, + b'f', + b'o', + PACKET_SEPARATOR, + b'f', + ])); + let mut payload = Payload::new(ProtocolVersion::V4, data); + + assert_eq!(payload.next().unwrap().unwrap(), "foo"); + assert_eq!(payload.next().unwrap().unwrap(), "fo"); + assert_eq!(payload.next().unwrap().unwrap(), "f"); + assert_eq!(payload.next().is_none(), true); + + Ok(()) + } + + #[test] + fn test_payload_iterator_v3() -> Result<(), String> { + assert!(cfg!(feature = "v3")); + + let data = BufReader::new(Cursor::new(vec![ + b'3', b':', b'f', b'o', b'o', b'2', b':', b'f', b'o', b'1', b':', b'f', + ])); + let mut payload = Payload::new(ProtocolVersion::V3, data); + + assert_eq!(payload.next().unwrap().unwrap(), "foo"); + assert_eq!(payload.next().unwrap().unwrap(), "fo"); + assert_eq!(payload.next().unwrap().unwrap(), "f"); + assert_eq!(payload.next().is_none(), true); + + Ok(()) + } +} diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs index b5d747ff..c23d67cd 100644 --- a/engineioxide/src/service.rs +++ b/engineioxide/src/service.rs @@ -2,10 +2,7 @@ use crate::{ body::ResponseBody, config::EngineIoConfig, engine::EngineIo, - errors::{ - Error, - Error::{UnknownTransport, UnsupportedProtocolVersion}, - }, + errors::{Error, Error::UnknownTransport}, futures::ResponseFuture, handler::EngineIoHandler, sid_generator::Sid, @@ -101,25 +98,33 @@ where let engine = self.engine.clone(); match RequestInfo::parse(&req) { Ok(RequestInfo { + protocol, sid: None, transport: TransportType::Polling, method: Method::GET, - }) => ResponseFuture::ready(engine.on_open_http_req(req)), + }) => ResponseFuture::ready(engine.on_open_http_req(protocol, req)), Ok(RequestInfo { + protocol, sid: Some(sid), transport: TransportType::Polling, method: Method::GET, - }) => ResponseFuture::async_response(Box::pin(engine.on_polling_http_req(sid))), + }) => ResponseFuture::async_response(Box::pin( + engine.on_polling_http_req(protocol, sid), + )), Ok(RequestInfo { + protocol, sid: Some(sid), transport: TransportType::Polling, method: Method::POST, - }) => ResponseFuture::async_response(Box::pin(engine.on_post_http_req(sid, req))), + }) => ResponseFuture::async_response(Box::pin( + engine.on_post_http_req(protocol, sid, req), + )), Ok(RequestInfo { + protocol, sid, transport: TransportType::Websocket, method: Method::GET, - }) => ResponseFuture::ready(engine.on_ws_req(sid, req)), + }) => ResponseFuture::ready(engine.on_ws_req(protocol, sid, req)), Err(e) => ResponseFuture::ready(Ok(e.into())), _ => ResponseFuture::empty_response(400), } @@ -207,9 +212,48 @@ impl FromStr for TransportType { } } +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum ProtocolVersion { + V3 = 3, + V4 = 4, +} + +impl FromStr for ProtocolVersion { + type Err = Error; + + #[cfg(all(feature = "v3", feature = "v4"))] + fn from_str(s: &str) -> Result { + match s { + "3" => Ok(ProtocolVersion::V3), + "4" => Ok(ProtocolVersion::V4), + _ => Err(Error::UnsupportedProtocolVersion), + } + } + + #[cfg(feature = "v4")] + #[cfg(not(feature = "v3"))] + fn from_str(s: &str) -> Result { + match s { + "4" => Ok(ProtocolVersion::V4), + _ => Err(Error::UnsupportedProtocolVersion), + } + } + + #[cfg(feature = "v3")] + #[cfg(not(feature = "v4"))] + fn from_str(s: &str) -> Result { + match s { + "3" => Ok(ProtocolVersion::V3), + _ => Err(Error::UnsupportedProtocolVersion), + } + } +} + /// The request information extracted from the request URI. #[derive(Debug)] struct RequestInfo { + /// The protocol version used by the client. + protocol: ProtocolVersion, /// The socket id if present in the request. sid: Option, /// The transport type used by the client. @@ -222,9 +266,13 @@ impl RequestInfo { /// Parse the request URI to extract the [`TransportType`](crate::service::TransportType) and the socket id. fn parse(req: &Request) -> Result { let query = req.uri().query().ok_or(UnknownTransport)?; - if !query.contains("EIO=4") { - return Err(UnsupportedProtocolVersion); - } + + let protocol: ProtocolVersion = query + .split('&') + .find(|s| s.starts_with("EIO=")) + .and_then(|s| s.split('=').nth(1)) + .ok_or(UnknownTransport) + .and_then(|t| t.parse())?; let sid = query .split('&') @@ -238,12 +286,14 @@ impl RequestInfo { .and_then(|s| s.split('=').nth(1)) .ok_or(UnknownTransport) .and_then(|t| t.parse())?; + let method = req.method().clone(); if !matches!(method, Method::GET) && sid.is_none() { Err(Error::BadHandshakeMethod) } else { Ok(RequestInfo { + protocol, sid, transport, method, @@ -266,6 +316,7 @@ mod tests { let info = RequestInfo::parse(&req).unwrap(); assert_eq!(info.sid, None); assert_eq!(info.transport, TransportType::Polling); + assert_eq!(info.protocol, ProtocolVersion::V4); assert_eq!(info.method, Method::GET); } @@ -275,28 +326,31 @@ mod tests { let info = RequestInfo::parse(&req).unwrap(); assert_eq!(info.sid, None); assert_eq!(info.transport, TransportType::Websocket); + assert_eq!(info.protocol, ProtocolVersion::V4); assert_eq!(info.method, Method::GET); } #[test] fn request_info_polling_with_sid() { let req = build_request( - "http://localhost:3000/socket.io/?EIO=4&transport=polling&sid=AAAAAAAAAHs", + "http://localhost:3000/socket.io/?EIO=3&transport=polling&sid=AAAAAAAAAHs", ); let info = RequestInfo::parse(&req).unwrap(); assert_eq!(info.sid, Some(123i64.into())); assert_eq!(info.transport, TransportType::Polling); + assert_eq!(info.protocol, ProtocolVersion::V3); assert_eq!(info.method, Method::GET); } #[test] fn request_info_websocket_with_sid() { let req = build_request( - "http://localhost:3000/socket.io/?EIO=4&transport=websocket&sid=AAAAAAAAAHs", + "http://localhost:3000/socket.io/?EIO=3&transport=websocket&sid=AAAAAAAAAHs", ); let info = RequestInfo::parse(&req).unwrap(); assert_eq!(info.sid, Some(123i64.into())); assert_eq!(info.transport, TransportType::Websocket); + assert_eq!(info.protocol, ProtocolVersion::V3); assert_eq!(info.method, Method::GET); } #[test] diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index 8880485f..e6b89c8a 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -16,7 +16,7 @@ use tracing::debug; use crate::sid_generator::Sid; use crate::{ config::EngineIoConfig, errors::Error, handler::EngineIoHandler, packet::Packet, - utils::forward_map_chan, SendPacket, + service::ProtocolVersion, utils::forward_map_chan, SendPacket, }; #[derive(Debug, Clone, PartialEq)] @@ -68,6 +68,9 @@ where /// The socket id pub sid: Sid, + /// The protocol version used by the socket + pub protocol: ProtocolVersion, + /// The connection type represented as a bitfield /// It is represented as a bitfield to allow the use of an [`AtomicU8`] so it can be shared between threads /// without any mutex @@ -86,12 +89,12 @@ where internal_tx: mpsc::Sender, pub tx: mpsc::Sender, - /// Internal channel to receive Pong [`Packets`](Packet) in the heartbeat job + /// Internal channel to receive Pong [`Packets`](Packet) (v4 protocol) or Ping (v3 protocol) in the heartbeat job /// which is running in a separate task - pong_rx: Mutex>, - /// Channel to send Ping [`Packets`](Packet) from the connexion to the heartbeat job + heartbeat_rx: Mutex>, + /// Channel to send Ping [`Packets`](Packet) (v4 protocol) or Ping (v3 protocol) from the connexion to the heartbeat job /// which is running in a separate task - pub(crate) pong_tx: mpsc::Sender<()>, + pub(crate) heartbeat_tx: mpsc::Sender<()>, /// Handle to the heartbeat job so that it can be aborted when the socket is closed heartbeat_handle: Mutex>>, @@ -110,6 +113,7 @@ where { pub(crate) fn new( sid: Sid, + protocol: ProtocolVersion, conn: ConnectionType, config: &EngineIoConfig, req_data: SocketReq, @@ -117,20 +121,21 @@ where ) -> Self { let (internal_tx, internal_rx) = mpsc::channel(config.max_buffer_size); let (tx, rx) = mpsc::channel(config.max_buffer_size); - let (pong_tx, pong_rx) = mpsc::channel(1); + let (heartbeat_tx, heartbeat_rx) = mpsc::channel(1); tokio::spawn(forward_map_chan(rx, internal_tx.clone(), SendPacket::into)); Self { sid, + protocol, conn: AtomicU8::new(conn as u8), internal_rx: Mutex::new(internal_rx), internal_tx, tx, - pong_rx: Mutex::new(pong_rx), - pong_tx, + heartbeat_rx: Mutex::new(heartbeat_rx), + heartbeat_tx, heartbeat_handle: Mutex::new(None), close_fn, @@ -171,14 +176,45 @@ where .replace(handle); } + /// Heartbeat is sent every `interval` milliseconds and the client or server (depending on the protocol) is expected to respond within `timeout` milliseconds. + /// + /// If the client or server does not respond within the timeout, the connection is closed. + #[cfg(all(feature = "v3", feature = "v4"))] + async fn heartbeat_job(&self, interval: Duration, timeout: Duration) -> Result<(), Error> { + match self.protocol { + ProtocolVersion::V3 => self.heartbeat_job_v3(timeout).await, + ProtocolVersion::V4 => self.heartbeat_job_v4(interval, timeout).await, + } + } + + /// Heartbeat is sent every `interval` milliseconds by the client and the server is expected to respond within `timeout` milliseconds. + /// + /// If the client or server does not respond within the timeout, the connection is closed. + #[cfg(feature = "v3")] + #[cfg(not(feature = "v4"))] + async fn heartbeat_job(&self, interval: Duration, timeout: Duration) -> Result<(), Error> { + self.heartbeat_job_v3(timeout) + } + /// Heartbeat is sent every `interval` milliseconds and the client is expected to respond within `timeout` milliseconds. /// /// If the client does not respond within the timeout, the connection is closed. + #[cfg(feature = "v4")] + #[cfg(not(feature = "v3"))] async fn heartbeat_job(&self, interval: Duration, timeout: Duration) -> Result<(), Error> { - let mut pong_rx = self - .pong_rx + self.heartbeat_job_v4(interval, timeout).await + } + + /// Heartbeat is sent every `interval` milliseconds and the client is expected to respond within `timeout` milliseconds. + /// + /// If the client does not respond within the timeout, the connection is closed. + #[cfg(feature = "v4")] + async fn heartbeat_job_v4(&self, interval: Duration, timeout: Duration) -> Result<(), Error> { + let mut heartbeat_rx = self + .heartbeat_rx .try_lock() .expect("Pong rx should be locked only once"); + let instant = tokio::time::Instant::now(); let mut interval_tick = tokio::time::interval(interval); interval_tick.tick().await; @@ -187,15 +223,17 @@ where 15 + instant.elapsed().as_millis() as u64, ))) .await; - debug!("[sid={}] heartbeat routine started", self.sid); + + debug!("[sid={}] heartbeat sender routine started", self.sid); + loop { // Some clients send the pong packet in first. If that happens, we should consume it. - pong_rx.try_recv().ok(); + heartbeat_rx.try_recv().ok(); self.internal_tx .try_send(Packet::Ping) .map_err(|_| Error::HeartbeatTimeout)?; - tokio::time::timeout(timeout, pong_rx.recv()) + tokio::time::timeout(timeout, heartbeat_rx.recv()) .await .map_err(|_| Error::HeartbeatTimeout)? .ok_or(Error::HeartbeatTimeout)?; @@ -203,6 +241,28 @@ where } } + #[cfg(feature = "v3")] + async fn heartbeat_job_v3(&self, timeout: Duration) -> Result<(), Error> { + let mut heartbeat_rx = self + .heartbeat_rx + .try_lock() + .expect("Pong rx should be locked only once"); + + debug!("[sid={}] heartbeat receiver routine started", self.sid); + + loop { + tokio::time::timeout(timeout, heartbeat_rx.recv()) + .await + .map_err(|_| Error::HeartbeatTimeout)? + .ok_or(Error::HeartbeatTimeout)?; + + debug!("[sid={}] ping received, sending pong", self.sid); + self.internal_tx + .try_send(Packet::Pong) + .map_err(|_| Error::HeartbeatTimeout)?; + } + } + /// Returns true if the [`Socket`] has a websocket [`ConnectionType`] pub(crate) fn is_ws(&self) -> bool { self.conn.load(Ordering::Relaxed) == ConnectionType::WebSocket as u8 @@ -245,7 +305,12 @@ where /// /// ⚠️ If the buffer is full or the socket is disconnected, an error will be returned pub fn emit_binary(&self, data: Vec) -> Result<(), Error> { - self.send(Packet::Binary(data))?; + if self.protocol == ProtocolVersion::V3 { + self.send(Packet::BinaryV3(data))?; + } else { + self.send(Packet::Binary(data))?; + } + Ok(()) } } @@ -255,20 +320,21 @@ impl Socket { pub fn new_dummy(sid: Sid, close_fn: Box) -> Socket { let (internal_tx, internal_rx) = mpsc::channel(200); let (tx, rx) = mpsc::channel(200); - let (pong_tx, pong_rx) = mpsc::channel(1); + let (heartbeat_tx, heartbeat_rx) = mpsc::channel(1); tokio::spawn(forward_map_chan(rx, internal_tx.clone(), SendPacket::into)); Self { sid, + protocol: ProtocolVersion::V4, conn: AtomicU8::new(ConnectionType::WebSocket as u8), internal_rx: Mutex::new(internal_rx), internal_tx, tx, - pong_rx: Mutex::new(pong_rx), - pong_tx, + heartbeat_rx: Mutex::new(heartbeat_rx), + heartbeat_tx, heartbeat_handle: Mutex::new(None), close_fn, diff --git a/socketioxide/src/packet.rs b/socketioxide/src/packet.rs index 672a329a..fdc8d1c2 100644 --- a/socketioxide/src/packet.rs +++ b/socketioxide/src/packet.rs @@ -338,7 +338,7 @@ impl TryFrom for Packet { let data = chars.as_str(); let inner = match index { - '0' => PacketData::Connect(deserialize_packet(data)?.unwrap_or(json!({}))), + '0' => PacketData::Connect(deserialize_packet(data)?.unwrap_or_else(|| json!({}))), '1' => PacketData::Disconnect, '2' => { let (event, payload) = deserialize_event_packet(data)?;