From 8058405dcf4b8f5975df5fc6133052bb7330d79d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9odore=20Pr=C3=A9vot?= Date: Sat, 17 Jun 2023 20:13:37 +0200 Subject: [PATCH 01/43] Create CODE_OF_CONDUCT.md --- CODE_OF_CONDUCT.md | 128 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 CODE_OF_CONDUCT.md diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..e9e2ffe6 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +prevottheodore@gmail.com. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. From 0a5a393c47e27c5800bb8785f6cf6c434f8a4307 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 18 Jun 2023 13:29:42 +0200 Subject: [PATCH 02/43] Feature: implement EngineIO V3 counting of characters In the v3 protocol, characters are counted instead of using a record separator (\x1e). See: https://github.com/socketio/engine.io-protocol/tree/v3#payload --- engineioxide/src/config.rs | 15 ++++ engineioxide/src/engine.rs | 131 +++++++++++++++++++++++++++++++++-- engineioxide/src/lib.rs | 1 + engineioxide/src/packet.rs | 17 ++++- engineioxide/src/protocol.rs | 19 +++++ engineioxide/src/service.rs | 12 +++- 6 files changed, 187 insertions(+), 8 deletions(-) create mode 100644 engineioxide/src/protocol.rs diff --git a/engineioxide/src/config.rs b/engineioxide/src/config.rs index c6c9dbb3..6ba0180b 100644 --- a/engineioxide/src/config.rs +++ b/engineioxide/src/config.rs @@ -1,5 +1,7 @@ use std::time::Duration; +use crate::protocol::ProtocolVersion; + #[derive(Debug, Clone)] pub struct EngineIoConfig { /// The path to listen for engine.io requests on. @@ -22,6 +24,10 @@ pub struct EngineIoConfig { /// The maximum number of bytes that can be received per http request. /// Defaults to 100kb. pub max_payload: u64, + + /// Protocol version. + /// Supports version `4` (default) and `3`. + pub protocol: ProtocolVersion, } impl Default for EngineIoConfig { @@ -32,6 +38,7 @@ impl Default for EngineIoConfig { ping_timeout: Duration::from_millis(20000), max_buffer_size: 128, max_payload: 1e5 as u64, // 100kb + protocol: ProtocolVersion::V4, } } } @@ -51,6 +58,14 @@ impl EngineIoConfigBuilder { config: EngineIoConfig::default(), } } + + /// The protocol version to use. + /// Defaults to version `4`. + pub fn protocol_version(mut self, protocol: ProtocolVersion) -> Self { + self.config.protocol = protocol; + self + } + /// The path to listen for engine.io requests on. /// Defaults to "/engine.io". pub fn req_path(mut self, req_path: String) -> Self { diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 484c7b14..9dce5d1e 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -5,7 +5,7 @@ use std::{ sync::{Arc, RwLock}, }; -use crate::sid_generator::Sid; +use crate::{sid_generator::Sid, protocol::ProtocolVersion}; use crate::{ body::ResponseBody, config::EngineIoConfig, @@ -26,7 +26,7 @@ use tokio_tungstenite::{ tungstenite::{protocol::Role, Message}, WebSocketStream, }; -use tracing::debug; +use tracing::{debug}; type SocketMap = RwLock>>; /// Abstract engine implementation for Engine.IO server for http polling and websocket @@ -134,7 +134,10 @@ where debug!("sending packet: {:?}", packet); let packet: String = packet.try_into().unwrap(); if !data.is_empty() { - data.push('\x1e'); + match self.config.protocol { + ProtocolVersion::V4 => data.push_str("\x1e"), + ProtocolVersion::V3 => data.push_str(&format!("{}:", packet.chars().count())), + } } data.push_str(&packet); } @@ -166,15 +169,21 @@ where debug!("error aggregating body: {:?}", e); Error::HttpErrorResponse(StatusCode::BAD_REQUEST) })?; - let packets = body.reader().split(b'\x1e'); + + let packets = self.parse_packets(body.reader()).map_err(|e| { + debug!("error parsing packets: {:?}", e); + Error::HttpErrorResponse( + StatusCode::BAD_REQUEST, + ) + })?; let socket = self .get_socket(sid) .and_then(|s| s.is_http().then(|| s)) .ok_or(Error::HttpErrorResponse(StatusCode::BAD_REQUEST))?; - for packet in packets { - let packet = match Packet::try_from(packet?) { + for p in packets { + let packet = match p { Ok(p) => p, Err(e) => { debug!("[sid={sid}] error parsing packet: {:?}", e); @@ -182,6 +191,7 @@ where return Err(Error::HttpErrorResponse(StatusCode::BAD_REQUEST)); } }; + match packet { Packet::Close => { debug!("[sid={sid}] closing session"); @@ -210,6 +220,52 @@ where Ok(http_response(StatusCode::OK, "ok")?) } + fn parse_packets(&self, mut reader: R) -> Result>, String> { + match self.config.protocol { + ProtocolVersion::V4 => { + let raw_packets = reader + .split(b'\x1e') + .map(|e| e.unwrap()); + + let packets = raw_packets.map(|raw_packet| { + Packet::try_from(raw_packet) + }); + + Ok(Box::new(packets) as Box>>) + } + ProtocolVersion::V3 => { + let mut line = String::new(); + let mut i = 0; + + if let Some(bytes_read) = reader.read_line(&mut line).map_err(|e| e.to_string()).ok() { + if bytes_read == 0 { + return Err("no bytes read".into()); + } + } else { + return Err("failed to read line".into()); + } + + let iter = std::iter::from_fn(move || { + while i < line.chars().count() { + if let Some(index) = line.chars().skip(i).collect::().find(':') { + if let Ok(length) = line.chars().take(i + index).skip(i).collect::().parse::() { + let start = i + index + 1; + let end = start + length; + let raw_packet = line.chars().take(end).skip(start).collect::(); + i = end; + return Some(Packet::try_from(raw_packet)); + } + } + break; + } + None + }); + + Ok(Box::new(iter) as Box>>) + }, + } + } + /// Upgrade a websocket request to create a websocket connection. /// /// If a sid is provided in the query it means that is is upgraded from an existing HTTP polling request. In this case @@ -449,3 +505,66 @@ where self.sockets.read().unwrap().get(&sid).cloned() } } + + +#[cfg(test)] +mod tests { + use std::{io::{BufReader, Cursor}}; + + use async_trait::async_trait; + + use super::*; + + #[derive(Debug, Clone)] + struct MockHandler; + + #[async_trait] + impl EngineIoHandler for MockHandler { + type Data = (); + + fn on_connect(&self, socket: &Socket) { + println!("socket connect {}", socket.sid); + } + + fn on_disconnect(&self, socket: &Socket) { + println!("socket disconnect {}", socket.sid); + } + + fn on_message(&self, msg: String, socket: &Socket) { + println!("Ping pong message {:?}", msg); + socket.emit(msg).ok(); + } + + fn on_binary(&self, data: Vec, socket: &Socket) { + println!("Ping pong binary message {:?}", data); + socket.emit_binary(data).ok(); + } + } + + #[test] + fn test_parse_v3_packets() -> Result<(), String> { + let handler = Arc::new(MockHandler); + let config = EngineIoConfig::builder() + .protocol_version(ProtocolVersion::V3) + .build(); + let engine = Arc::new(EngineIo::from_config(handler, config)); + + let mut reader = BufReader::new(Cursor::new("6:4hello2:4€")); + let packets = engine + .parse_packets(reader)?.collect::, Error>>() + .map_err(|e| e.to_string())?; + assert_eq!(packets.len(), 2); + assert_eq!(packets[0], Packet::Message("hello".into())); + assert_eq!(packets[1], Packet::Message("€".into())); + + reader = BufReader::new(Cursor::new("2:4€10:b4AQIDBA==")); + let packets = engine + .parse_packets(reader)?.collect::, Error>>() + .map_err(|e| e.to_string())?; + assert_eq!(packets.len(), 2); + assert_eq!(packets[0], Packet::Message("€".into())); + assert_eq!(packets[1], Packet::Binary(vec![1, 2, 3, 4])); + + Ok(()) + } +} diff --git a/engineioxide/src/lib.rs b/engineioxide/src/lib.rs index d36d737c..17147aa0 100644 --- a/engineioxide/src/lib.rs +++ b/engineioxide/src/lib.rs @@ -10,6 +10,7 @@ pub mod layer; pub mod service; pub mod sid_generator; pub mod socket; +pub mod protocol; mod body; mod engine; diff --git a/engineioxide/src/packet.rs b/engineioxide/src/packet.rs index 01376dee..90c7d352 100644 --- a/engineioxide/src/packet.rs +++ b/engineioxide/src/packet.rs @@ -100,7 +100,15 @@ impl TryFrom for Packet { '4' => Packet::Message(packet_data.to_string()), '5' => Packet::Upgrade, '6' => Packet::Noop, - 'b' => Packet::Binary(general_purpose::STANDARD.decode(packet_data.as_bytes())?), + 'b' => { + let mut packet_data = packet_data; + + if packet_data.starts_with('4') { + packet_data = &packet_data[1..]; + } + + Packet::Binary(general_purpose::STANDARD.decode(packet_data.as_bytes())?) + }, c => Err(serde_json::Error::custom( "Invalid packet type ".to_string() + &c.to_string(), ))?, @@ -109,6 +117,13 @@ impl TryFrom for Packet { } } +impl TryFrom<&str> for Packet { + type Error = crate::errors::Error; + fn try_from(value: &str) -> Result { + Packet::try_from(value.to_string()) + } +} + /// Deserialize a Binary Packet variant from a [Vec] according to the Engine.IO protocol /// Used when receiving data from a websocket binary frame impl TryFrom> for Packet { diff --git a/engineioxide/src/protocol.rs b/engineioxide/src/protocol.rs new file mode 100644 index 00000000..18d43a41 --- /dev/null +++ b/engineioxide/src/protocol.rs @@ -0,0 +1,19 @@ +use std::str::FromStr; + +#[derive(Debug, Clone, PartialEq)] +pub enum ProtocolVersion { + V3 = 3, + V4 = 4, +} + +impl FromStr for ProtocolVersion { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "3" => Ok(ProtocolVersion::V3), + "4" => Ok(ProtocolVersion::V4), + _ => Err(()), + } + } +} diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs index c98b775a..a470320b 100644 --- a/engineioxide/src/service.rs +++ b/engineioxide/src/service.rs @@ -1,3 +1,4 @@ +use crate::protocol::ProtocolVersion; use crate::sid_generator::Sid; use crate::{ body::ResponseBody, config::EngineIoConfig, engine::EngineIo, futures::ResponseFuture, @@ -231,7 +232,16 @@ impl RequestInfo { /// Parse the request URI to extract the [`TransportType`](crate::service::TransportType) and the socket id. fn parse(req: &Request) -> Option { let query = req.uri().query()?; - if !query.contains("EIO=4") { + + let protocol: ProtocolVersion = query + .split('&') + .find(|s| s.starts_with("EIO="))? + .split('=') + .nth(1)? + .parse() + .ok()?; + + if protocol != ProtocolVersion::V4 && protocol != ProtocolVersion::V3 { return None; } From 9024f40ddb8958b65ea85b72359b6cc8eea0615e Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 18 Jun 2023 18:08:52 +0200 Subject: [PATCH 03/43] Feat: determine protocol version from URL query --- engineioxide/src/config.rs | 14 -------------- engineioxide/src/engine.rs | 22 +++++++++++----------- engineioxide/src/service.rs | 15 +++++++++------ 3 files changed, 20 insertions(+), 31 deletions(-) diff --git a/engineioxide/src/config.rs b/engineioxide/src/config.rs index 6ba0180b..4c972e76 100644 --- a/engineioxide/src/config.rs +++ b/engineioxide/src/config.rs @@ -1,7 +1,5 @@ use std::time::Duration; -use crate::protocol::ProtocolVersion; - #[derive(Debug, Clone)] pub struct EngineIoConfig { /// The path to listen for engine.io requests on. @@ -24,10 +22,6 @@ pub struct EngineIoConfig { /// The maximum number of bytes that can be received per http request. /// Defaults to 100kb. pub max_payload: u64, - - /// Protocol version. - /// Supports version `4` (default) and `3`. - pub protocol: ProtocolVersion, } impl Default for EngineIoConfig { @@ -38,7 +32,6 @@ impl Default for EngineIoConfig { ping_timeout: Duration::from_millis(20000), max_buffer_size: 128, max_payload: 1e5 as u64, // 100kb - protocol: ProtocolVersion::V4, } } } @@ -59,13 +52,6 @@ impl EngineIoConfigBuilder { } } - /// The protocol version to use. - /// Defaults to version `4`. - pub fn protocol_version(mut self, protocol: ProtocolVersion) -> Self { - self.config.protocol = protocol; - self - } - /// The path to listen for engine.io requests on. /// Defaults to "/engine.io". pub fn req_path(mut self, req_path: String) -> Self { diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 9dce5d1e..c99b79e2 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -104,6 +104,7 @@ where /// Otherwise it will wait for the next packet to be sent from the socket pub(crate) async fn on_polling_http_req( self: Arc, + protocol: ProtocolVersion, sid: Sid, ) -> Result>, Error> where @@ -134,7 +135,7 @@ where debug!("sending packet: {:?}", packet); let packet: String = packet.try_into().unwrap(); if !data.is_empty() { - match self.config.protocol { + match protocol { ProtocolVersion::V4 => data.push_str("\x1e"), ProtocolVersion::V3 => data.push_str(&format!("{}:", packet.chars().count())), } @@ -156,6 +157,7 @@ where /// Split the body into packets and send them to the internal socket pub(crate) async fn on_post_http_req( self: Arc, + protocol: ProtocolVersion, sid: Sid, body: Request, ) -> Result>, Error> @@ -170,7 +172,7 @@ where Error::HttpErrorResponse(StatusCode::BAD_REQUEST) })?; - let packets = self.parse_packets(body.reader()).map_err(|e| { + let packets = self.parse_packets(body.reader(), protocol).map_err(|e| { debug!("error parsing packets: {:?}", e); Error::HttpErrorResponse( StatusCode::BAD_REQUEST, @@ -220,8 +222,8 @@ where Ok(http_response(StatusCode::OK, "ok")?) } - fn parse_packets(&self, mut reader: R) -> Result>, String> { - match self.config.protocol { + fn parse_packets(&self, mut reader: R, protocol: ProtocolVersion) -> Result>, String> { + match protocol { ProtocolVersion::V4 => { let raw_packets = reader .split(b'\x1e') @@ -543,15 +545,13 @@ mod tests { #[test] fn test_parse_v3_packets() -> Result<(), String> { + let protocol = ProtocolVersion::V3; let handler = Arc::new(MockHandler); - let config = EngineIoConfig::builder() - .protocol_version(ProtocolVersion::V3) - .build(); - let engine = Arc::new(EngineIo::from_config(handler, config)); - + let engine = Arc::new(EngineIo::new(handler)); + let mut reader = BufReader::new(Cursor::new("6:4hello2:4€")); let packets = engine - .parse_packets(reader)?.collect::, Error>>() + .parse_packets(reader, protocol.clone())?.collect::, Error>>() .map_err(|e| e.to_string())?; assert_eq!(packets.len(), 2); assert_eq!(packets[0], Packet::Message("hello".into())); @@ -559,7 +559,7 @@ mod tests { reader = BufReader::new(Cursor::new("2:4€10:b4AQIDBA==")); let packets = engine - .parse_packets(reader)?.collect::, Error>>() + .parse_packets(reader, protocol.clone())?.collect::, Error>>() .map_err(|e| e.to_string())?; assert_eq!(packets.len(), 2); assert_eq!(packets[0], Packet::Message("€".into())); diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs index a470320b..f451b5a0 100644 --- a/engineioxide/src/service.rs +++ b/engineioxide/src/service.rs @@ -109,21 +109,25 @@ where let engine = self.engine.clone(); match RequestInfo::parse(&req) { Some(RequestInfo { + protocol: _, sid: None, transport: TransportType::Polling, method: Method::GET, }) => ResponseFuture::ready(engine.on_open_http_req(req)), Some(RequestInfo { + protocol, sid: Some(sid), transport: TransportType::Polling, method: Method::GET, - }) => ResponseFuture::async_response(Box::pin(engine.on_polling_http_req(sid))), + }) => ResponseFuture::async_response(Box::pin(engine.on_polling_http_req(protocol, sid))), Some(RequestInfo { + protocol, sid: Some(sid), transport: TransportType::Polling, method: Method::POST, - }) => ResponseFuture::async_response(Box::pin(engine.on_post_http_req(sid, req))), + }) => ResponseFuture::async_response(Box::pin(engine.on_post_http_req(protocol, sid, req))), Some(RequestInfo { + protocol: _, sid, transport: TransportType::Websocket, method: Method::GET, @@ -220,6 +224,8 @@ impl FromStr for TransportType { /// The request information extracted from the request URI. struct RequestInfo { + /// The protocol version used by the client. + protocol: ProtocolVersion, /// The socket id if present in the request. sid: Option, /// The transport type used by the client. @@ -241,10 +247,6 @@ impl RequestInfo { .parse() .ok()?; - if protocol != ProtocolVersion::V4 && protocol != ProtocolVersion::V3 { - return None; - } - let sid = query .split('&') .find(|s| s.starts_with("sid=")) @@ -260,6 +262,7 @@ impl RequestInfo { .ok()?; Some(RequestInfo { + protocol, sid, transport, method: req.method().clone(), From 3841cc247309d833ef01c995015b04a3854925d1 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 18 Jun 2023 18:14:50 +0200 Subject: [PATCH 04/43] Refactor: use position method to find index --- engineioxide/src/engine.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index c99b79e2..654f3218 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -249,7 +249,7 @@ where let iter = std::iter::from_fn(move || { while i < line.chars().count() { - if let Some(index) = line.chars().skip(i).collect::().find(':') { + if let Some(index) = line.chars().skip(i).position(|c| c == ':') { if let Ok(length) = line.chars().take(i + index).skip(i).collect::().parse::() { let start = i + index + 1; let end = start + length; From a116a85622956f6949ce00551606ebbef5c7612a Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 18 Jun 2023 18:57:45 +0200 Subject: [PATCH 05/43] Refactor: improve check for V3 binary packet --- engineioxide/src/packet.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/engineioxide/src/packet.rs b/engineioxide/src/packet.rs index 90c7d352..779a2eb7 100644 --- a/engineioxide/src/packet.rs +++ b/engineioxide/src/packet.rs @@ -103,8 +103,8 @@ impl TryFrom for Packet { 'b' => { let mut packet_data = packet_data; - if packet_data.starts_with('4') { - packet_data = &packet_data[1..]; + if value.starts_with("b4") { + packet_data = &packet_data[1..]; } Packet::Binary(general_purpose::STANDARD.decode(packet_data.as_bytes())?) From bccd62ee7e85580585e65f4f0c42ff60d4df2755 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 18 Jun 2023 21:15:48 +0200 Subject: [PATCH 06/43] Feat: add missing character counter --- engineioxide/src/engine.rs | 6 +++++- engineioxide/src/service.rs | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 654f3218..8dd025a4 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -69,6 +69,7 @@ where /// Send an open packet pub(crate) fn on_open_http_req( self: Arc, + protocol: ProtocolVersion, req: Request, ) -> Result>, Error> where @@ -94,7 +95,10 @@ where self.handler.on_connect(&socket); let packet = OpenPacket::new(TransportType::Polling, sid, &self.config); - let packet: String = Packet::Open(packet).try_into()?; + let mut packet: String = Packet::Open(packet).try_into()?; + if protocol == ProtocolVersion::V3 { + packet = format!("{}:{}", packet.chars().count(), packet); + } http_response(StatusCode::OK, packet).map_err(Error::Http) } diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs index f451b5a0..76ab6cc8 100644 --- a/engineioxide/src/service.rs +++ b/engineioxide/src/service.rs @@ -109,11 +109,11 @@ where let engine = self.engine.clone(); match RequestInfo::parse(&req) { Some(RequestInfo { - protocol: _, + protocol, sid: None, transport: TransportType::Polling, method: Method::GET, - }) => ResponseFuture::ready(engine.on_open_http_req(req)), + }) => ResponseFuture::ready(engine.on_open_http_req(protocol, req)), Some(RequestInfo { protocol, sid: Some(sid), From eaab5e47578b817d33228f2c01a19cfa7958e647 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 25 Jun 2023 12:36:00 +0200 Subject: [PATCH 07/43] Feat: implement reverse ping/pong for V3 --- engineioxide/src/engine.rs | 49 ++++++++++++++++-------- engineioxide/src/service.rs | 4 +- engineioxide/src/socket.rs | 76 +++++++++++++++++++++++++------------ 3 files changed, 87 insertions(+), 42 deletions(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 8dd025a4..07c776e4 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -26,7 +26,7 @@ use tokio_tungstenite::{ tungstenite::{protocol::Role, Message}, WebSocketStream, }; -use tracing::{debug}; +use tracing::debug; type SocketMap = RwLock>>; /// Abstract engine implementation for Engine.IO server for http polling and websocket @@ -91,7 +91,7 @@ where } socket .clone() - .spawn_heartbeat(self.config.ping_interval, self.config.ping_timeout); + .spawn_heartbeat(protocol.clone(), self.config.ping_interval, self.config.ping_timeout); self.handler.on_connect(&socket); let packet = OpenPacket::new(TransportType::Polling, sid, &self.config); @@ -151,6 +151,9 @@ where if data.is_empty() { let packet = rx.recv().await.ok_or(Error::Aborted)?; let packet: String = packet.try_into().unwrap(); + if protocol == ProtocolVersion::V3 { + data.push_str(&format!("{}:", packet.chars().count())); + } data.push_str(&packet); } Ok(http_response(StatusCode::OK, data)?) @@ -205,10 +208,12 @@ where self.close_session(sid); break; } - Packet::Pong => socket - .pong_tx - .try_send(()) - .map_err(|_| Error::HeartbeatTimeout), + Packet::Pong | Packet::Ping => { + socket + .pong_tx + .try_send(()) + .map_err(|_| Error::HeartbeatTimeout) + }, Packet::Message(msg) => { self.handler.on_message(msg, &socket); Ok(()) @@ -278,6 +283,7 @@ where /// the http polling request is closed and the SID is kept for the websocket pub(crate) fn on_ws_req( self: Arc, + protocol: ProtocolVersion, sid: Option, req: Request, ) -> Result>, Error> { @@ -292,7 +298,7 @@ where let req = Request::from_parts(parts, ()); tokio::spawn(async move { match hyper::upgrade::on(req).await { - Ok(conn) => match self.on_ws_req_init(conn, sid, req_data).await { + Ok(conn) => match self.on_ws_req_init(conn, protocol, sid, req_data).await { Ok(_) => debug!("ws closed"), Err(e) => debug!("ws closed with error: {:?}", e), }, @@ -311,6 +317,7 @@ where async fn on_ws_req_init( self: Arc, conn: Upgraded, + protocol: ProtocolVersion, sid: Option, req_data: SocketReq, ) -> Result<(), Error> { @@ -335,7 +342,7 @@ where self.ws_init_handshake(sid, &mut ws).await?; socket .clone() - .spawn_heartbeat(self.config.ping_interval, self.config.ping_timeout); + .spawn_heartbeat(protocol.clone(), self.config.ping_interval, self.config.ping_timeout); socket } else { let sid = sid.unwrap(); @@ -359,7 +366,10 @@ where Packet::Binary(bin) => tx.send(Message::Binary(bin)).await, Packet::Close => tx.send(Message::Close(None)).await, _ => { - let packet: String = item.try_into().unwrap(); + let mut packet: String = item.try_into().unwrap(); + if protocol.clone() == ProtocolVersion::V3 { + packet = format!("{}:{}", packet.chars().count(), packet) + } tx.send(Message::Text(packet)).await } }; @@ -389,16 +399,18 @@ where while let Ok(msg) = rx.try_next().await { let Some(msg) = msg else { continue }; match msg { - Message::Text(msg) => match Packet::try_from(msg)? { + Message::Text(msg) => match Packet::try_from(msg.clone())? { Packet::Close => { debug!("[sid={}] closing session", socket.sid); self.close_session(socket.sid); break; } - Packet::Pong => socket - .pong_tx - .try_send(()) - .map_err(|_| Error::HeartbeatTimeout), + Packet::Pong | Packet::Ping => { + socket + .pong_tx + .try_send(()) + .map_err(|_| Error::HeartbeatTimeout) + }, Packet::Message(msg) => { self.handler.on_message(msg, socket); Ok(()) @@ -465,7 +477,14 @@ where // Fetch the next packet from the ws stream, it should be an Upgrade packet let msg = match ws.next().await { Some(Ok(Message::Text(d))) => d, - _ => Err(Error::UpgradeError)?, + Some(Ok(Message::Close(_))) => { + debug!("ws stream closed before upgrade"); + Err(Error::UpgradeError)? + }, + _ => { + debug!("unexpected ws message before upgrade"); + Err(Error::UpgradeError)? + }, }; match Packet::try_from(msg)? { Packet::Upgrade => debug!("[sid={sid}] ws upgraded successful"), diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs index 76ab6cc8..2347c262 100644 --- a/engineioxide/src/service.rs +++ b/engineioxide/src/service.rs @@ -127,11 +127,11 @@ where method: Method::POST, }) => ResponseFuture::async_response(Box::pin(engine.on_post_http_req(protocol, sid, req))), Some(RequestInfo { - protocol: _, + protocol, sid, transport: TransportType::Websocket, method: Method::GET, - }) => ResponseFuture::ready(engine.on_ws_req(sid, req)), + }) => ResponseFuture::ready(engine.on_ws_req(protocol, sid, req)), _ => ResponseFuture::empty_response(400), } } else { diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index 8880485f..700cb303 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -13,11 +13,11 @@ use tokio::{ }; use tracing::debug; -use crate::sid_generator::Sid; use crate::{ config::EngineIoConfig, errors::Error, handler::EngineIoHandler, packet::Packet, utils::forward_map_chan, SendPacket, }; +use crate::{protocol::ProtocolVersion, sid_generator::Sid}; #[derive(Debug, Clone, PartialEq)] pub(crate) enum ConnectionType { @@ -156,11 +156,16 @@ where /// Spawn the heartbeat job /// /// Keep a handle to the job so that it can be aborted when the socket is closed - pub(crate) fn spawn_heartbeat(self: Arc, interval: Duration, timeout: Duration) { + pub(crate) fn spawn_heartbeat( + self: Arc, + protocol: ProtocolVersion, + interval: Duration, + timeout: Duration, + ) { let socket = self.clone(); let handle = tokio::spawn(async move { - if let Err(e) = socket.heartbeat_job(interval, timeout).await { + if let Err(e) = socket.heartbeat_job(protocol, interval, timeout).await { socket.close(); debug!("[sid={}] heartbeat error: {:?}", socket.sid, e); } @@ -174,32 +179,53 @@ where /// Heartbeat is sent every `interval` milliseconds and the client is expected to respond within `timeout` milliseconds. /// /// If the client does not respond within the timeout, the connection is closed. - async fn heartbeat_job(&self, interval: Duration, timeout: Duration) -> Result<(), Error> { + async fn heartbeat_job( + &self, + protocol: ProtocolVersion, + interval: Duration, + timeout: Duration, + ) -> Result<(), Error> { let mut pong_rx = self .pong_rx .try_lock() .expect("Pong rx should be locked only once"); - let instant = tokio::time::Instant::now(); - let mut interval_tick = tokio::time::interval(interval); - interval_tick.tick().await; - // Sleep for an interval minus the time it took to get here - tokio::time::sleep(interval.saturating_sub(Duration::from_millis( - 15 + instant.elapsed().as_millis() as u64, - ))) - .await; - debug!("[sid={}] heartbeat routine started", self.sid); - loop { - // Some clients send the pong packet in first. If that happens, we should consume it. - pong_rx.try_recv().ok(); - - self.internal_tx - .try_send(Packet::Ping) - .map_err(|_| Error::HeartbeatTimeout)?; - tokio::time::timeout(timeout, pong_rx.recv()) - .await - .map_err(|_| Error::HeartbeatTimeout)? - .ok_or(Error::HeartbeatTimeout)?; - interval_tick.tick().await; + + match protocol { + ProtocolVersion::V3 => { + debug!("[sid={}] heartbeat receiver routine started", self.sid); + loop { + if pong_rx.recv().await.is_some() { + debug!("[sid={}] ping received, sending pong", self.sid); + self.internal_tx + .try_send(Packet::Pong) + .map_err(|_| Error::HeartbeatTimeout)?; + } + } + } + ProtocolVersion::V4 => { + let instant = tokio::time::Instant::now(); + let mut interval_tick = tokio::time::interval(interval); + interval_tick.tick().await; + // Sleep for an interval minus the time it took to get here + tokio::time::sleep(interval.saturating_sub(Duration::from_millis( + 15 + instant.elapsed().as_millis() as u64, + ))) + .await; + debug!("[sid={}] heartbeat sender routine started", self.sid); + loop { + // Some clients send the pong packet in first. If that happens, we should consume it. + pong_rx.try_recv().ok(); + + self.internal_tx + .try_send(Packet::Ping) + .map_err(|_| Error::HeartbeatTimeout)?; + tokio::time::timeout(timeout, pong_rx.recv()) + .await + .map_err(|_| Error::HeartbeatTimeout)? + .ok_or(Error::HeartbeatTimeout)?; + interval_tick.tick().await; + } + } } } From a26e73a282eccd9a9f466cab1ca46c5a192de439 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Mon, 26 Jun 2023 19:07:54 +0200 Subject: [PATCH 08/43] Fix: implement missing paused polling transport The V3 protocol handles transport upgrades slightly different. This commit solves an issue where the transport would be forever stuck in a polling state, even though the client thinks it has already been upgraded to websocket. --- engineioxide/src/engine.rs | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 6842ef17..faf3f0ec 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -316,7 +316,7 @@ impl EngineIo Some(_) => { debug!("[sid={sid}] websocket connection upgrade"); let mut ws = ws_init().await; - self.ws_upgrade_handshake(sid, &mut ws).await?; + self.ws_upgrade_handshake(protocol.clone(), sid, &mut ws).await?; (self.get_socket(sid).unwrap(), ws) } } @@ -355,7 +355,7 @@ impl EngineIo Packet::Close => tx.send(Message::Close(None)).await, _ => { let mut packet: String = item.try_into().unwrap(); - if protocol.clone() == ProtocolVersion::V3 { + if protocol == ProtocolVersion::V3 { packet = format!("{}:{}", packet.chars().count(), packet) } tx.send(Message::Text(packet)).await @@ -441,12 +441,16 @@ impl EngineIo /// ``` async fn ws_upgrade_handshake( &self, + protocol: ProtocolVersion, sid: Sid, ws: &mut WebSocketStream, ) -> Result<(), Error> { let socket = self.get_socket(sid).unwrap(); - // send a NOOP packet to any pending polling request so it closes gracefully - socket.send(Packet::Noop)?; + + // send a NOOP packet to any pending polling request so it closes gracefully' + if protocol == ProtocolVersion::V4 { + socket.send(Packet::Noop)?; + } // Fetch the next packet from the ws stream, it should be a PingUpgrade packet let msg = match ws.next().await { @@ -462,6 +466,13 @@ impl EngineIo p => Err(Error::BadPacket(p))?, }; + // send a NOOP packet to any pending polling request so it closes gracefully + // V3 protocol introduce _paused_ polling transport which require to close + // the polling request **after** the ping/pong handshake + if protocol == ProtocolVersion::V3 { + socket.send(Packet::Noop)?; + } + // Fetch the next packet from the ws stream, it should be an Upgrade packet let msg = match ws.next().await { Some(Ok(Message::Text(d))) => d, From 3b502f47c5c3b16bb809fa46cd83e51931c18bad Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Mon, 26 Jun 2023 21:36:32 +0200 Subject: [PATCH 09/43] Fix: don't wrap piped packet in character counter --- engineioxide/src/engine.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index faf3f0ec..f10f4099 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -354,10 +354,7 @@ impl EngineIo Packet::Binary(bin) => tx.send(Message::Binary(bin)).await, Packet::Close => tx.send(Message::Close(None)).await, _ => { - let mut packet: String = item.try_into().unwrap(); - if protocol == ProtocolVersion::V3 { - packet = format!("{}:{}", packet.chars().count(), packet) - } + let packet: String = item.try_into().unwrap(); tx.send(Message::Text(packet)).await } }; From e5e1db623309075ee757c28e6f60640114509aa2 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Wed, 28 Jun 2023 12:23:56 +0200 Subject: [PATCH 10/43] Fix: prefix character counter to single packet --- engineioxide/src/engine.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index f10f4099..05e655fb 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -130,6 +130,8 @@ impl EngineIo ProtocolVersion::V4 => data.push_str("\x1e"), ProtocolVersion::V3 => data.push_str(&format!("{}:", packet.chars().count())), } + } else if protocol == ProtocolVersion::V3 { + data.push_str(&format!("{}:", packet.chars().count())); } data.push_str(&packet); } From 711811f3e85aaaae5a21ddd59eeedf53b411ac1d Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Wed, 28 Jun 2023 22:17:40 +0200 Subject: [PATCH 11/43] Refactor: move payload parsing to separate file --- engineioxide/src/engine.rs | 98 ++++---------------------- engineioxide/src/lib.rs | 1 + engineioxide/src/payload.rs | 132 ++++++++++++++++++++++++++++++++++++ 3 files changed, 145 insertions(+), 86 deletions(-) create mode 100644 engineioxide/src/payload.rs diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 05e655fb..c05e4742 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -1,11 +1,10 @@ #![deny(clippy::await_holding_lock)] use std::{ collections::HashMap, - io::BufRead, sync::{Arc, RwLock}, }; -use crate::{sid_generator::Sid, protocol::ProtocolVersion}; +use crate::{sid_generator::Sid, protocol::ProtocolVersion, payload::Payload}; use crate::{ body::ResponseBody, config::EngineIoConfig, @@ -167,21 +166,23 @@ impl EngineIo debug!("error aggregating body: {:?}", e); Error::HttpErrorResponse(StatusCode::BAD_REQUEST) })?; - - let packets = self.parse_packets(body.reader(), protocol).map_err(|e| { - debug!("error parsing packets: {:?}", e); - Error::HttpErrorResponse( - StatusCode::BAD_REQUEST, - ) - })?; - + let socket = self .get_socket(sid) .ok_or(Error::UnknownSessionID(sid)) .and_then(|s| s.is_http().then(|| s).ok_or(Error::TransportMismatch))?; + let packets = Payload::new(body.reader(), protocol); + for p in packets { - let packet = match p { + let raw_packet = p.map_err(|e| { + debug!("error parsing packets: {:?}", e); + Error::HttpErrorResponse( + StatusCode::BAD_REQUEST, + ) + })?; + + let packet = match Packet::try_from(raw_packet) { Ok(p) => p, Err(e) => { debug!("[sid={sid}] error parsing packet: {:?}", e); @@ -220,52 +221,6 @@ impl EngineIo Ok(http_response(StatusCode::OK, "ok")?) } - fn parse_packets(&self, mut reader: R, protocol: ProtocolVersion) -> Result>, String> { - match protocol { - ProtocolVersion::V4 => { - let raw_packets = reader - .split(b'\x1e') - .map(|e| e.unwrap()); - - let packets = raw_packets.map(|raw_packet| { - Packet::try_from(raw_packet) - }); - - Ok(Box::new(packets) as Box>>) - } - ProtocolVersion::V3 => { - let mut line = String::new(); - let mut i = 0; - - if let Some(bytes_read) = reader.read_line(&mut line).map_err(|e| e.to_string()).ok() { - if bytes_read == 0 { - return Err("no bytes read".into()); - } - } else { - return Err("failed to read line".into()); - } - - let iter = std::iter::from_fn(move || { - while i < line.chars().count() { - if let Some(index) = line.chars().skip(i).position(|c| c == ':') { - if let Ok(length) = line.chars().take(i + index).skip(i).collect::().parse::() { - let start = i + index + 1; - let end = start + length; - let raw_packet = line.chars().take(end).skip(start).collect::(); - i = end; - return Some(Packet::try_from(raw_packet)); - } - } - break; - } - None - }); - - Ok(Box::new(iter) as Box>>) - }, - } - } - /// Upgrade a websocket request to create a websocket connection. /// /// If a sid is provided in the query it means that is is upgraded from an existing HTTP polling request. In this case @@ -532,8 +487,6 @@ impl EngineIo #[cfg(test)] mod tests { - use std::{io::{BufReader, Cursor}}; - use async_trait::async_trait; use super::*; @@ -563,31 +516,4 @@ mod tests { socket.emit_binary(data).ok(); } } - - #[test] - fn test_parse_v3_packets() -> Result<(), String> { - let protocol = ProtocolVersion::V3; - let handler = MockHandler; - let engine = EngineIo::new(handler, Default::default()); - - let mut reader = BufReader::new(Cursor::new("6:4hello2:4€")); - let packets = engine - .parse_packets(reader, protocol.clone())? - .collect::, Error>>() - .map_err(|e| e.to_string())?; - assert_eq!(packets.len(), 2); - assert_eq!(packets[0], Packet::Message("hello".into())); - assert_eq!(packets[1], Packet::Message("€".into())); - - reader = BufReader::new(Cursor::new("2:4€10:b4AQIDBA==")); - let packets = engine - .parse_packets(reader, protocol.clone())? - .collect::, Error>>() - .map_err(|e| e.to_string())?; - assert_eq!(packets.len(), 2); - assert_eq!(packets[0], Packet::Message("€".into())); - assert_eq!(packets[1], Packet::Binary(vec![1, 2, 3, 4])); - - Ok(()) - } } diff --git a/engineioxide/src/lib.rs b/engineioxide/src/lib.rs index 17147aa0..15e7f9a5 100644 --- a/engineioxide/src/lib.rs +++ b/engineioxide/src/lib.rs @@ -17,3 +17,4 @@ mod engine; mod futures; mod packet; mod utils; +mod payload; diff --git a/engineioxide/src/payload.rs b/engineioxide/src/payload.rs new file mode 100644 index 00000000..805e80fb --- /dev/null +++ b/engineioxide/src/payload.rs @@ -0,0 +1,132 @@ +use std::{io::BufRead, vec}; + +use crate::protocol::ProtocolVersion; + +const PACKET_SEPARATOR: u8 = b'\x1e'; + +/// A payload is a series of encoded packets tied together. +/// How packets are tied together depends on the protocol. +pub struct Payload { + reader: R, + buffer: Vec, + protocol: ProtocolVersion, +} + +impl Payload { + pub fn new(data: R, protocol: ProtocolVersion) -> Self { + Payload { + reader: data, + buffer: vec![], + protocol, + } + } +} + +impl Iterator for Payload { + type Item = Result; + + fn next(&mut self) -> Option { + self.buffer.clear(); + + match self.protocol { + ProtocolVersion::V3 => { + match self.reader.read_until(b':', &mut self.buffer) { + Ok(bytes_read) => { + if bytes_read > 0 { + // remove trailing separator + if self.buffer.ends_with(&[b':']) { + self.buffer.pop(); + } + + let length = match String::from_utf8(self.buffer.clone()) { + Ok(s) => { + if let Ok(l) = s.parse::() { + l + } else { + return Some(Err("Invalid packet length".into())); + } + }, + Err(_) => return Some(Err("Invalid packet length".into())), + }; + + self.buffer.clear(); + self.buffer.resize(length, 0); + + match self.reader.read_exact(&mut self.buffer) { + Ok(_) => { + match String::from_utf8(self.buffer.clone()) { + Ok(s) => Some(Ok(s)), + Err(_) => Some(Err("Invalid packet data".into())), + } + }, + Err(err) => Some(Err(err.to_string())), + } + } else { + None + } + } + Err(err) => Some(Err(err.to_string())), + } + }, + ProtocolVersion::V4 => { + match self.reader.read_until(PACKET_SEPARATOR, &mut self.buffer) { + Ok(bytes_read) => { + if bytes_read > 0 { + // remove trailing separator + if self.buffer.ends_with(&[PACKET_SEPARATOR]) { + self.buffer.pop(); + } + + match String::from_utf8( self.buffer.clone()) { + Ok(s) => Some(Ok(s)), + Err(_) => Some(Err("Packet is not a valid UTF-8 string".into())), + } + } else { + None + } + } + Err(err) => Some(Err(err.to_string())), + } + }, + } + } +} + +#[cfg(test)] +mod tests { + use std::{io::{BufReader, Cursor}, vec}; + + use crate::protocol::ProtocolVersion; + + use super::{Payload, PACKET_SEPARATOR}; + + #[test] + fn test_payload_iterator_v4() -> Result<(), String> { + let data = BufReader::new(Cursor::new(vec![ + b'f', b'o', b'o', PACKET_SEPARATOR, b'f', b'o', PACKET_SEPARATOR, b'f', + ])); + let mut payload = Payload::new(data, ProtocolVersion::V4); + + assert_eq!(payload.next(), Some(Ok("foo".into()))); + assert_eq!(payload.next(), Some(Ok("fo".into()))); + assert_eq!(payload.next(), Some(Ok("f".into()))); + assert_eq!(payload.next(), None); + + Ok(()) + } + + #[test] + fn test_payload_iterator_v3() -> Result<(), String> { + let data = BufReader::new(Cursor::new(vec![ + b'3', b':', b'f', b'o', b'o', b'2', b':', b'f', b'o', b'1', b':', b'f', + ])); + let mut payload = Payload::new(data, ProtocolVersion::V3); + + assert_eq!(payload.next(), Some(Ok("foo".into()))); + assert_eq!(payload.next(), Some(Ok("fo".into()))); + assert_eq!(payload.next(), Some(Ok("f".into()))); + assert_eq!(payload.next(), None); + + Ok(()) + } +} From 9e2645a07c9242a794b7b9bec21c08decc3cef86 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Wed, 28 Jun 2023 22:17:56 +0200 Subject: [PATCH 12/43] Fix: close session on invalid packet --- engineioxide/src/engine.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index c05e4742..41674424 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -177,6 +177,7 @@ impl EngineIo for p in packets { let raw_packet = p.map_err(|e| { debug!("error parsing packets: {:?}", e); + self.close_session(sid); Error::HttpErrorResponse( StatusCode::BAD_REQUEST, ) From 2c3feb45ec7d3cead831d75900de05cce5d092e6 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Thu, 29 Jun 2023 23:12:45 +0200 Subject: [PATCH 13/43] Feat: add v3 protocol ping/pong timeout --- engineioxide/src/socket.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index 700cb303..1649ee39 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -194,12 +194,15 @@ where ProtocolVersion::V3 => { debug!("[sid={}] heartbeat receiver routine started", self.sid); loop { - if pong_rx.recv().await.is_some() { - debug!("[sid={}] ping received, sending pong", self.sid); - self.internal_tx - .try_send(Packet::Pong) - .map_err(|_| Error::HeartbeatTimeout)?; - } + tokio::time::timeout(timeout, pong_rx.recv()) + .await + .map_err(|_| Error::HeartbeatTimeout)? + .ok_or(Error::HeartbeatTimeout)?; + + debug!("[sid={}] ping received, sending pong", self.sid); + self.internal_tx + .try_send(Packet::Pong) + .map_err(|_| Error::HeartbeatTimeout)?; } } ProtocolVersion::V4 => { From 7738b416170b64a6637c69df955ba3bf3f2bc2cf Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sat, 1 Jul 2023 12:41:58 +0200 Subject: [PATCH 14/43] Refactor: remove unused TryFrom implementation --- engineioxide/src/packet.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/engineioxide/src/packet.rs b/engineioxide/src/packet.rs index 779a2eb7..6d201d24 100644 --- a/engineioxide/src/packet.rs +++ b/engineioxide/src/packet.rs @@ -117,13 +117,6 @@ impl TryFrom for Packet { } } -impl TryFrom<&str> for Packet { - type Error = crate::errors::Error; - fn try_from(value: &str) -> Result { - Packet::try_from(value.to_string()) - } -} - /// Deserialize a Binary Packet variant from a [Vec] according to the Engine.IO protocol /// Used when receiving data from a websocket binary frame impl TryFrom> for Packet { From a88a0b0a37f64d666a47dd8e19671effd5f54d50 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sat, 1 Jul 2023 14:17:17 +0200 Subject: [PATCH 15/43] Feat: enable protocol versions through features --- .github/workflows/socketio-ci.yml | 2 +- Cargo.lock | 1 + engineioxide/Cargo.toml | 8 +- engineioxide/Readme.md | 20 ++++ engineioxide/src/engine.rs | 2 +- engineioxide/src/payload.rs | 155 ++++++++++++++++++------------ engineioxide/src/socket.rs | 139 +++++++++++++++++++-------- 7 files changed, 222 insertions(+), 105 deletions(-) diff --git a/.github/workflows/socketio-ci.yml b/.github/workflows/socketio-ci.yml index 0a71f721..16088fcb 100644 --- a/.github/workflows/socketio-ci.yml +++ b/.github/workflows/socketio-ci.yml @@ -18,7 +18,7 @@ jobs: - uses: actions-rs/toolchain@v1 with: toolchain: stable - - run: cargo test + - run: cargo test --all-features e2e: runs-on: ubuntu-latest diff --git a/Cargo.lock b/Cargo.lock index c2da666c..0c49a003 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -358,6 +358,7 @@ dependencies = [ "base64 0.21.0", "base64id", "bytes", + "cfg-if", "criterion", "futures", "http", diff --git a/engineioxide/Cargo.toml b/engineioxide/Cargo.toml index 1256242e..f26b400c 100644 --- a/engineioxide/Cargo.toml +++ b/engineioxide/Cargo.toml @@ -30,10 +30,16 @@ tower = "0.4.13" tracing = "0.1.37" rand = "0.8.5" base64id = { version = "0.3.1", features = ["std", "rand", "serde"] } +cfg-if = "1.0.0" [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] } [[bench]] name = "benchmark_polling" -harness = false \ No newline at end of file +harness = false + +[features] +default = ["v4"] +v4 = [] +v3 = [] diff --git a/engineioxide/Readme.md b/engineioxide/Readme.md index a5b9ce7b..e5f360ef 100644 --- a/engineioxide/Readme.md +++ b/engineioxide/Readme.md @@ -47,3 +47,23 @@ async fn main() -> Result<(), Box> { Ok(()) } ``` + +### Supported Protocols +You can enable support for other EngineIO protocol implementations through feature flags. +The latest supported protocol version is enabled by default. + +To add support for another protocol version, adjust your dependency configuration accordingly: + +```toml +[dependencies] +# Enables the `v3` protocol (`v4` is also implicitly enabled, as it's the default). +engineioxide = { version = "0.3.0", features = ["v3"] } +``` + +To enable *a single protocol version only*, disable default features: + +```toml +[dependencies] +# Enables the `v3` protocol only. +engineioxide = { version = "0.3.0", features = ["v3"], default-features = false } +``` diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 41674424..0889eca7 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -172,7 +172,7 @@ impl EngineIo .ok_or(Error::UnknownSessionID(sid)) .and_then(|s| s.is_http().then(|| s).ok_or(Error::TransportMismatch))?; - let packets = Payload::new(body.reader(), protocol); + let packets = Payload::new(protocol, body.reader()); for p in packets { let raw_packet = p.map_err(|e| { diff --git a/engineioxide/src/payload.rs b/engineioxide/src/payload.rs index 805e80fb..b0e2a21c 100644 --- a/engineioxide/src/payload.rs +++ b/engineioxide/src/payload.rs @@ -1,7 +1,9 @@ use std::{io::BufRead, vec}; +use cfg_if::cfg_if; use crate::protocol::ProtocolVersion; +#[cfg(feature = "v4")] const PACKET_SEPARATOR: u8 = b'\x1e'; /// A payload is a series of encoded packets tied together. @@ -9,85 +11,112 @@ const PACKET_SEPARATOR: u8 = b'\x1e'; pub struct Payload { reader: R, buffer: Vec, + #[allow(dead_code)] protocol: ProtocolVersion, } +type Item = Result; + impl Payload { - pub fn new(data: R, protocol: ProtocolVersion) -> Self { + pub fn new(protocol: ProtocolVersion, data: R) -> Self { Payload { reader: data, buffer: vec![], protocol, } } -} - -impl Iterator for Payload { - type Item = Result; - fn next(&mut self) -> Option { + #[cfg(feature = "v3")] + fn next_v3(&mut self) -> Option { self.buffer.clear(); - match self.protocol { - ProtocolVersion::V3 => { - match self.reader.read_until(b':', &mut self.buffer) { - Ok(bytes_read) => { - if bytes_read > 0 { - // remove trailing separator - if self.buffer.ends_with(&[b':']) { - self.buffer.pop(); + match self.reader.read_until(b':', &mut self.buffer) { + Ok(bytes_read) => { + if bytes_read > 0 { + // remove trailing separator + if self.buffer.ends_with(&[b':']) { + self.buffer.pop(); + } + + let length = match String::from_utf8(self.buffer.clone()) { + Ok(s) => { + if let Ok(l) = s.parse::() { + l + } else { + return Some(Err("Invalid packet length".into())); } + }, + Err(_) => return Some(Err("Invalid packet length".into())), + }; + + self.buffer.clear(); + self.buffer.resize(length, 0); - let length = match String::from_utf8(self.buffer.clone()) { - Ok(s) => { - if let Ok(l) = s.parse::() { - l - } else { - return Some(Err("Invalid packet length".into())); - } - }, - Err(_) => return Some(Err("Invalid packet length".into())), - }; - - self.buffer.clear(); - self.buffer.resize(length, 0); - - match self.reader.read_exact(&mut self.buffer) { - Ok(_) => { - match String::from_utf8(self.buffer.clone()) { - Ok(s) => Some(Ok(s)), - Err(_) => Some(Err("Invalid packet data".into())), - } - }, - Err(err) => Some(Err(err.to_string())), + match self.reader.read_exact(&mut self.buffer) { + Ok(_) => { + match String::from_utf8(self.buffer.clone()) { + Ok(s) => Some(Ok(s)), + Err(_) => Some(Err("Invalid packet data".into())), } - } else { - None - } + }, + Err(err) => Some(Err(err.to_string())), } - Err(err) => Some(Err(err.to_string())), + } else { + None } - }, - ProtocolVersion::V4 => { - match self.reader.read_until(PACKET_SEPARATOR, &mut self.buffer) { - Ok(bytes_read) => { - if bytes_read > 0 { - // remove trailing separator - if self.buffer.ends_with(&[PACKET_SEPARATOR]) { - self.buffer.pop(); - } - - match String::from_utf8( self.buffer.clone()) { - Ok(s) => Some(Ok(s)), - Err(_) => Some(Err("Packet is not a valid UTF-8 string".into())), - } - } else { - None - } + } + Err(err) => Some(Err(err.to_string())), + } + } + + #[cfg(feature = "v4")] + fn next_v4(&mut self) -> Option { + self.buffer.clear(); + + match self.reader.read_until(PACKET_SEPARATOR, &mut self.buffer) { + Ok(bytes_read) => { + if bytes_read > 0 { + // remove trailing separator + if self.buffer.ends_with(&[PACKET_SEPARATOR]) { + self.buffer.pop(); } - Err(err) => Some(Err(err.to_string())), + + match String::from_utf8( self.buffer.clone()) { + Ok(s) => Some(Ok(s)), + Err(_) => Some(Err("Packet is not a valid UTF-8 string".into())), + } + } else { + None + } + } + Err(err) => Some(Err(err.to_string())), + } + } +} + +impl Iterator for Payload { + type Item = Item; + + cfg_if! { + if #[cfg(all(feature = "v3", feature = "v4"))] { + fn next(&mut self) -> Option { + match self.protocol { + ProtocolVersion::V3 => { + self.next_v3() + }, + ProtocolVersion::V4 => { + self.next_v4() + }, } - }, + } + } else if #[cfg(feature = "v3")] { + fn next(&mut self) -> Option { + self.next_v3() + } + } else { + fn next(&mut self) -> Option { + self.next_v4() + } } } } @@ -102,10 +131,12 @@ mod tests { #[test] fn test_payload_iterator_v4() -> Result<(), String> { + assert!(cfg!(feature = "v4")); + let data = BufReader::new(Cursor::new(vec![ b'f', b'o', b'o', PACKET_SEPARATOR, b'f', b'o', PACKET_SEPARATOR, b'f', ])); - let mut payload = Payload::new(data, ProtocolVersion::V4); + let mut payload = Payload::new(ProtocolVersion::V4, data); assert_eq!(payload.next(), Some(Ok("foo".into()))); assert_eq!(payload.next(), Some(Ok("fo".into()))); @@ -117,10 +148,12 @@ mod tests { #[test] fn test_payload_iterator_v3() -> Result<(), String> { + assert!(cfg!(feature = "v3")); + let data = BufReader::new(Cursor::new(vec![ b'3', b':', b'f', b'o', b'o', b'2', b':', b'f', b'o', b'1', b':', b'f', ])); - let mut payload = Payload::new(data, ProtocolVersion::V3); + let mut payload = Payload::new(ProtocolVersion::V3, data); assert_eq!(payload.next(), Some(Ok("foo".into()))); assert_eq!(payload.next(), Some(Ok("fo".into()))); diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index 1649ee39..b54b4020 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -6,6 +6,7 @@ use std::{ time::Duration, }; +use cfg_if::cfg_if; use http::{request::Parts, Uri}; use tokio::{ sync::{mpsc, mpsc::Receiver, Mutex}, @@ -176,12 +177,59 @@ where .replace(handle); } + cfg_if! { + if #[cfg(all(feature = "v3", feature = "v4"))] { + /// Heartbeat is sent every `interval` milliseconds and the client or server (depending on the protocol) is expected to respond within `timeout` milliseconds. + /// + /// If the client or server does not respond within the timeout, the connection is closed. + async fn heartbeat_job( + &self, + protocol: ProtocolVersion, + interval: Duration, + timeout: Duration, + ) -> Result<(), Error> { + match protocol { + ProtocolVersion::V3 => { + self.heartbeat_job_v3(timeout).await + } + ProtocolVersion::V4 => { + self.heartbeat_job_v4(interval, timeout).await + } + } + } + } else if #[cfg(feature = "v3")] { + /// Heartbeat is sent every `interval` milliseconds by the client and the server is expected to respond within `timeout` milliseconds. + /// + /// If the client or server does not respond within the timeout, the connection is closed. + async fn heartbeat_job( + &self, + _: ProtocolVersion, + interval: Duration, + timeout: Duration, + ) -> Result<(), Error> { + self.heartbeat_job_v3(timeout) + } + } else { + /// Heartbeat is sent every `interval` milliseconds and the client is expected to respond within `timeout` milliseconds. + /// + /// If the client does not respond within the timeout, the connection is closed. + async fn heartbeat_job( + &self, + _: ProtocolVersion, + interval: Duration, + timeout: Duration, + ) -> Result<(), Error> { + self.heartbeat_job_v4(interval, timeout).await + } + } + } + /// Heartbeat is sent every `interval` milliseconds and the client is expected to respond within `timeout` milliseconds. /// /// If the client does not respond within the timeout, the connection is closed. - async fn heartbeat_job( + #[cfg(feature = "v4")] + async fn heartbeat_job_v4( &self, - protocol: ProtocolVersion, interval: Duration, timeout: Duration, ) -> Result<(), Error> { @@ -190,45 +238,54 @@ where .try_lock() .expect("Pong rx should be locked only once"); - match protocol { - ProtocolVersion::V3 => { - debug!("[sid={}] heartbeat receiver routine started", self.sid); - loop { - tokio::time::timeout(timeout, pong_rx.recv()) - .await - .map_err(|_| Error::HeartbeatTimeout)? - .ok_or(Error::HeartbeatTimeout)?; - - debug!("[sid={}] ping received, sending pong", self.sid); - self.internal_tx - .try_send(Packet::Pong) - .map_err(|_| Error::HeartbeatTimeout)?; - } - } - ProtocolVersion::V4 => { - let instant = tokio::time::Instant::now(); - let mut interval_tick = tokio::time::interval(interval); - interval_tick.tick().await; - // Sleep for an interval minus the time it took to get here - tokio::time::sleep(interval.saturating_sub(Duration::from_millis( - 15 + instant.elapsed().as_millis() as u64, - ))) - .await; - debug!("[sid={}] heartbeat sender routine started", self.sid); - loop { - // Some clients send the pong packet in first. If that happens, we should consume it. - pong_rx.try_recv().ok(); - - self.internal_tx - .try_send(Packet::Ping) - .map_err(|_| Error::HeartbeatTimeout)?; - tokio::time::timeout(timeout, pong_rx.recv()) - .await - .map_err(|_| Error::HeartbeatTimeout)? - .ok_or(Error::HeartbeatTimeout)?; - interval_tick.tick().await; - } - } + let instant = tokio::time::Instant::now(); + let mut interval_tick = tokio::time::interval(interval); + interval_tick.tick().await; + // Sleep for an interval minus the time it took to get here + tokio::time::sleep(interval.saturating_sub(Duration::from_millis( + 15 + instant.elapsed().as_millis() as u64, + ))) + .await; + + debug!("[sid={}] heartbeat sender routine started", self.sid); + + loop { + // Some clients send the pong packet in first. If that happens, we should consume it. + pong_rx.try_recv().ok(); + + self.internal_tx + .try_send(Packet::Ping) + .map_err(|_| Error::HeartbeatTimeout)?; + tokio::time::timeout(timeout, pong_rx.recv()) + .await + .map_err(|_| Error::HeartbeatTimeout)? + .ok_or(Error::HeartbeatTimeout)?; + interval_tick.tick().await; + } + } + + #[cfg(feature = "v3")] + async fn heartbeat_job_v3( + &self, + timeout: Duration, + ) -> Result<(), Error> { + let mut pong_rx = self + .pong_rx + .try_lock() + .expect("Pong rx should be locked only once"); + + debug!("[sid={}] heartbeat receiver routine started", self.sid); + + loop { + tokio::time::timeout(timeout, pong_rx.recv()) + .await + .map_err(|_| Error::HeartbeatTimeout)? + .ok_or(Error::HeartbeatTimeout)?; + + debug!("[sid={}] ping received, sending pong", self.sid); + self.internal_tx + .try_send(Packet::Pong) + .map_err(|_| Error::HeartbeatTimeout)?; } } From 73c3e1d1b81ef446c0dfdb27403277a35bcf60f4 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sat, 1 Jul 2023 15:29:08 +0200 Subject: [PATCH 16/43] Refactor: eliminate more unused code based on protocol version --- engineioxide/src/engine.rs | 65 +++++++++++++++++++++++++++---------- engineioxide/src/service.rs | 23 ++++++++++++- 2 files changed, 70 insertions(+), 18 deletions(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 0889eca7..5ab138cf 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -17,6 +17,7 @@ use crate::{ socket::{ConnectionType, Socket, SocketReq}, }; use bytes::Buf; +use cfg_if::cfg_if; use futures::{stream::SplitStream, SinkExt, StreamExt, TryStreamExt}; use http::{Request, Response, StatusCode}; use hyper::upgrade::Upgraded; @@ -83,9 +84,14 @@ impl EngineIo self.handler.on_connect(&socket); let packet = OpenPacket::new(TransportType::Polling, sid, &self.config); + #[allow(unused_mut)] let mut packet: String = Packet::Open(packet).try_into()?; - if protocol == ProtocolVersion::V3 { - packet = format!("{}:{}", packet.chars().count(), packet); + cfg_if! { + if #[cfg(feature = "v3")] { + if protocol == ProtocolVersion::V3 { + packet = format!("{}:{}", packet.chars().count(), packet); + } + } } http_response(StatusCode::OK, packet).map_err(Error::Http) } @@ -125,12 +131,25 @@ impl EngineIo debug!("sending packet: {:?}", packet); let packet: String = packet.try_into().unwrap(); if !data.is_empty() { - match protocol { - ProtocolVersion::V4 => data.push_str("\x1e"), - ProtocolVersion::V3 => data.push_str(&format!("{}:", packet.chars().count())), + cfg_if! { + if #[cfg(feature = "v3")] { + if protocol == ProtocolVersion::V3 { + data.push_str(&format!("{}:", packet.chars().count())); + } + } + else if #[cfg(feature = "v4")] { + if protocol == ProtocolVersion::V4 { + data.push_str("\x1e"); + } + } + } + } + cfg_if! { + if #[cfg(feature = "v3")] { + if data.is_empty() && protocol == ProtocolVersion::V3 { + data.push_str(&format!("{}:", packet.chars().count())); + } } - } else if protocol == ProtocolVersion::V3 { - data.push_str(&format!("{}:", packet.chars().count())); } data.push_str(&packet); } @@ -139,8 +158,12 @@ impl EngineIo if data.is_empty() { let packet = rx.recv().await.ok_or(Error::Aborted)?; let packet: String = packet.try_into().unwrap(); - if protocol == ProtocolVersion::V3 { - data.push_str(&format!("{}:", packet.chars().count())); + cfg_if! { + if #[cfg(feature = "v3")] { + if protocol == ProtocolVersion::V3 { + data.push_str(&format!("{}:", packet.chars().count())); + } + } } data.push_str(&packet); } @@ -402,9 +425,13 @@ impl EngineIo ) -> Result<(), Error> { let socket = self.get_socket(sid).unwrap(); - // send a NOOP packet to any pending polling request so it closes gracefully' - if protocol == ProtocolVersion::V4 { - socket.send(Packet::Noop)?; + cfg_if! { + if #[cfg(feature = "v4")] { + // send a NOOP packet to any pending polling request so it closes gracefully' + if protocol == ProtocolVersion::V4 { + socket.send(Packet::Noop)?; + } + } } // Fetch the next packet from the ws stream, it should be a PingUpgrade packet @@ -421,11 +448,15 @@ impl EngineIo p => Err(Error::BadPacket(p))?, }; - // send a NOOP packet to any pending polling request so it closes gracefully - // V3 protocol introduce _paused_ polling transport which require to close - // the polling request **after** the ping/pong handshake - if protocol == ProtocolVersion::V3 { - socket.send(Packet::Noop)?; + cfg_if! { + if #[cfg(feature = "v3")] { + // send a NOOP packet to any pending polling request so it closes gracefully + // V3 protocol introduce _paused_ polling transport which require to close + // the polling request **after** the ping/pong handshake + if protocol == ProtocolVersion::V3 { + socket.send(Packet::Noop)?; + } + } } // Fetch the next packet from the ws stream, it should be an Upgrade packet diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs index 9dab8439..de116309 100644 --- a/engineioxide/src/service.rs +++ b/engineioxide/src/service.rs @@ -8,9 +8,10 @@ use crate::{ }, futures::ResponseFuture, handler::EngineIoHandler, - sid_generator::Sid, protocol::{ProtocolVersion}, + sid_generator::Sid, protocol::ProtocolVersion, }; use bytes::Bytes; +use cfg_if::cfg_if; use futures::future::{ready, Ready}; use http::{Method, Request}; use http_body::{Body, Empty}; @@ -236,6 +237,26 @@ impl RequestInfo { .ok_or(UnknownTransport) .and_then(|t| t.parse())?; + cfg_if! { + if #[cfg(all(feature = "v3", feature = "v4"))] { + if protocol != ProtocolVersion::V3 && protocol != ProtocolVersion::V4 { + return Err(Error::UnsupportedProtocolVersion); + } + } + else if #[cfg(feature = "v4")] { + if protocol != ProtocolVersion::V4 { + return Err(Error::UnsupportedProtocolVersion); + } + } + else if #[cfg(feature = "v3")] { + if protocol != ProtocolVersion::V3 { + return Err(Error::UnsupportedProtocolVersion); + } + } else { + compile_error!("At least one protocol version must be enabled"); + } + } + let sid = query .split('&') .find(|s| s.starts_with("sid=")) From d62cff2d4efda0375e3214077775f17d360fdfab Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sat, 1 Jul 2023 23:31:37 +0200 Subject: [PATCH 17/43] Feat: add special packet type for v3 binary data The v3 protocol requires not only the 'b' prefix for binary data, but also a '4' for the 'message' type. Thus, I've added a special `BinaryV3` enum field to handle this very specifc packet type. I know this looks a bit hacky, but it's the simplest solution I could come up with in regards to the way things are currently implemented. Besides, this is the only exception in all protocol versions where the message is encoded slightly different so imo it warrants the exception for the time being. --- engineioxide/src/engine.rs | 10 ++++++---- engineioxide/src/packet.rs | 32 +++++++++++++++++++++++++++----- engineioxide/src/socket.rs | 21 ++++++++++++++------- 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 5ab138cf..06a1e598 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -69,6 +69,7 @@ impl EngineIo let sid = generate_sid(); let socket = Socket::new( sid, + protocol.clone(), ConnectionType::Http, &self.config, SocketReq::from(req.into_parts().0), @@ -80,7 +81,7 @@ impl EngineIo } socket .clone() - .spawn_heartbeat(protocol.clone(), self.config.ping_interval, self.config.ping_timeout); + .spawn_heartbeat(self.config.ping_interval, self.config.ping_timeout); self.handler.on_connect(&socket); let packet = OpenPacket::new(TransportType::Polling, sid, &self.config); @@ -232,7 +233,7 @@ impl EngineIo self.handler.on_message(msg, &socket); Ok(()) } - Packet::Binary(bin) => { + Packet::Binary(bin) | Packet::BinaryV3(bin) => { self.handler.on_binary(bin, &socket); Ok(()) } @@ -307,6 +308,7 @@ impl EngineIo let close_fn = Box::new(move |sid: Sid| engine.close_session(sid)); let socket = Socket::new( sid, + protocol.clone(), ConnectionType::WebSocket, &self.config, req_data, @@ -321,7 +323,7 @@ impl EngineIo self.ws_init_handshake(sid, &mut ws).await?; socket .clone() - .spawn_heartbeat(protocol.clone(), self.config.ping_interval, self.config.ping_timeout); + .spawn_heartbeat(self.config.ping_interval, self.config.ping_timeout); (socket, ws) }; let (mut tx, rx) = ws.split(); @@ -332,7 +334,7 @@ impl EngineIo let mut socket_rx = rx_socket.internal_rx.try_lock().unwrap(); while let Some(item) = socket_rx.recv().await { let res = match item { - Packet::Binary(bin) => tx.send(Message::Binary(bin)).await, + Packet::Binary(bin) | Packet::BinaryV3(bin) => tx.send(Message::Binary(bin)).await, Packet::Close => tx.send(Message::Close(None)).await, _ => { let packet: String = item.try_into().unwrap(); diff --git a/engineioxide/src/packet.rs b/engineioxide/src/packet.rs index 6d201d24..bda99963 100644 --- a/engineioxide/src/packet.rs +++ b/engineioxide/src/packet.rs @@ -47,6 +47,15 @@ pub enum Packet { /// /// When receiving, it is only used with polling connection, websocket use binary frame Binary(Vec), // Not part of the protocol, used internally + + /// Binary packet used to send binary data to the client + /// Converts to a String using base64 encoding when using polling connection + /// Or to a websocket binary frame when using websocket connection + /// + /// When receiving, it is only used with polling connection, websocket use binary frame + /// + /// This is a special packet, excepionally specific to the V3 protocol. + BinaryV3(Vec), // Not part of the protocol, used internally } /// Serialize a [Packet] to a [String] according to the Engine.IO protocol @@ -66,6 +75,7 @@ impl TryInto for Packet { Packet::Upgrade => "5".to_string(), Packet::Noop => "6".to_string(), Packet::Binary(data) => "b".to_string() + &general_purpose::STANDARD.encode(data), + Packet::BinaryV3(data) => "b4".to_string() + &general_purpose::STANDARD.encode(data), }; Ok(res) } @@ -101,13 +111,11 @@ impl TryFrom for Packet { '5' => Packet::Upgrade, '6' => Packet::Noop, 'b' => { - let mut packet_data = packet_data; - if value.starts_with("b4") { - packet_data = &packet_data[1..]; + Packet::BinaryV3(general_purpose::STANDARD.decode(&packet_data[1..].as_bytes())?) + } else { + Packet::Binary(general_purpose::STANDARD.decode(packet_data.as_bytes())?) } - - Packet::Binary(general_purpose::STANDARD.decode(packet_data.as_bytes())?) }, c => Err(serde_json::Error::custom( "Invalid packet type ".to_string() + &c.to_string(), @@ -239,6 +247,20 @@ mod tests { assert_eq!(packet, Packet::Binary(vec![1, 2, 3])); } + #[test] + fn test_binary_packet_v3() { + let packet = Packet::BinaryV3(vec![1, 2, 3]); + let packet_str: String = packet.try_into().unwrap(); + assert_eq!(packet_str, "b4AQID"); + } + + #[test] + fn test_binary_packet_v3_deserialize() { + let packet_str = "b4AQID".to_string(); + let packet: Packet = packet_str.try_into().unwrap(); + assert_eq!(packet, Packet::BinaryV3(vec![1, 2, 3])); + } + #[test] fn test_send_packet_into_packet() { let packet = SendPacket::Message("hello".to_string()); diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index b54b4020..6973d95a 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -69,6 +69,9 @@ where /// The socket id pub sid: Sid, + /// The protocol version used by the socket + pub protocol: ProtocolVersion, + /// The connection type represented as a bitfield /// It is represented as a bitfield to allow the use of an [`AtomicU8`] so it can be shared between threads /// without any mutex @@ -111,6 +114,7 @@ where { pub(crate) fn new( sid: Sid, + protocol: ProtocolVersion, conn: ConnectionType, config: &EngineIoConfig, req_data: SocketReq, @@ -124,6 +128,7 @@ where Self { sid, + protocol, conn: AtomicU8::new(conn as u8), internal_rx: Mutex::new(internal_rx), @@ -159,14 +164,13 @@ where /// Keep a handle to the job so that it can be aborted when the socket is closed pub(crate) fn spawn_heartbeat( self: Arc, - protocol: ProtocolVersion, interval: Duration, timeout: Duration, ) { let socket = self.clone(); let handle = tokio::spawn(async move { - if let Err(e) = socket.heartbeat_job(protocol, interval, timeout).await { + if let Err(e) = socket.heartbeat_job(interval, timeout).await { socket.close(); debug!("[sid={}] heartbeat error: {:?}", socket.sid, e); } @@ -184,11 +188,10 @@ where /// If the client or server does not respond within the timeout, the connection is closed. async fn heartbeat_job( &self, - protocol: ProtocolVersion, interval: Duration, timeout: Duration, ) -> Result<(), Error> { - match protocol { + match self.protocol { ProtocolVersion::V3 => { self.heartbeat_job_v3(timeout).await } @@ -203,7 +206,6 @@ where /// If the client or server does not respond within the timeout, the connection is closed. async fn heartbeat_job( &self, - _: ProtocolVersion, interval: Duration, timeout: Duration, ) -> Result<(), Error> { @@ -215,7 +217,6 @@ where /// If the client does not respond within the timeout, the connection is closed. async fn heartbeat_job( &self, - _: ProtocolVersion, interval: Duration, timeout: Duration, ) -> Result<(), Error> { @@ -331,7 +332,12 @@ where /// /// ⚠️ If the buffer is full or the socket is disconnected, an error will be returned pub fn emit_binary(&self, data: Vec) -> Result<(), Error> { - self.send(Packet::Binary(data))?; + if self.protocol == ProtocolVersion::V3 { + self.send(Packet::BinaryV3(data))?; + } else { + self.send(Packet::Binary(data))?; + } + Ok(()) } } @@ -347,6 +353,7 @@ impl Socket { Self { sid, + protocol: ProtocolVersion::V4, conn: AtomicU8::new(ConnectionType::WebSocket as u8), internal_rx: Mutex::new(internal_rx), From c59bdd52db4c36f08570f859d0bad005f2a73621 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sat, 1 Jul 2023 23:54:07 +0200 Subject: [PATCH 18/43] Fix: remove overkill compile time checks These compile time checks were incorrectly added and breaking v4 tests. Given a second thought they don't bring major improvements to the code base and are hard to read so I decided to remove them. --- engineioxide/src/engine.rs | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 06a1e598..964e8348 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -132,25 +132,12 @@ impl EngineIo debug!("sending packet: {:?}", packet); let packet: String = packet.try_into().unwrap(); if !data.is_empty() { - cfg_if! { - if #[cfg(feature = "v3")] { - if protocol == ProtocolVersion::V3 { - data.push_str(&format!("{}:", packet.chars().count())); - } - } - else if #[cfg(feature = "v4")] { - if protocol == ProtocolVersion::V4 { - data.push_str("\x1e"); - } - } - } - } - cfg_if! { - if #[cfg(feature = "v3")] { - if data.is_empty() && protocol == ProtocolVersion::V3 { - data.push_str(&format!("{}:", packet.chars().count())); - } + match protocol { + ProtocolVersion::V3 => data.push_str(&format!("{}:", packet.chars().count())), + ProtocolVersion::V4 => data.push_str("\x1e"), } + } else if protocol == ProtocolVersion::V3 { + data.push_str(&format!("{}:", packet.chars().count())); } data.push_str(&packet); } From 8df81f42b1a6a3efc1600396a02a23e5d4f30fc9 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 00:15:17 +0200 Subject: [PATCH 19/43] Test: add engineio V3 test to CI --- .github/workflows/engineio-ci.yml | 40 +++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 .github/workflows/engineio-ci.yml diff --git a/.github/workflows/engineio-ci.yml b/.github/workflows/engineio-ci.yml new file mode 100644 index 00000000..f6d1fb20 --- /dev/null +++ b/.github/workflows/engineio-ci.yml @@ -0,0 +1,40 @@ + +name: EngineIO CI + +on: + push: + tags: + - v* + pull_request: + branches: + - main + - develop + +jobs: + e2e_v3: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - uses: actions-rs/toolchain@v1 + with: + toolchain: stable + - uses: actions/checkout@v3 + with: + # TODO: replace with official repo once https://github.com/socketio/engine.io-protocol/pull/45 has been merged + repository: sleeyax/engine.io-protocol + path: engine.io-protocol + ref: v3-testsuite + - uses: actions/setup-node@v3 + with: + node-version: 16 + - name: Install deps & run tests + run: | + cd engine.io-protocol/test-suite && npm install && cd ../.. + cargo run --bin engineioxide-e2e --release --features v3 --no-default-features > v3_server.txt & npm --prefix engine.io-protocol/test-suite test > v3_client.txt + - name: Server output + if: always() + run: cat v3_server.txt + - name: Client output + if: always() + run: cat v3_client.txt \ No newline at end of file From d4bfefc1d11d5d4eb0f8da96088dd749ff06933b Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 00:25:08 +0200 Subject: [PATCH 20/43] Refactor: add clippy suggestions --- engineioxide/src/engine.rs | 4 ++-- engineioxide/src/packet.rs | 4 ++-- engineioxide/src/payload.rs | 3 +-- socketioxide/src/packet.rs | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 964e8348..b3b7834c 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -4,7 +4,7 @@ use std::{ sync::{Arc, RwLock}, }; -use crate::{sid_generator::Sid, protocol::ProtocolVersion, payload::Payload}; +use crate::{sid_generator::Sid, protocol::ProtocolVersion, payload::{Payload, PACKET_SEPARATOR}}; use crate::{ body::ResponseBody, config::EngineIoConfig, @@ -134,7 +134,7 @@ impl EngineIo if !data.is_empty() { match protocol { ProtocolVersion::V3 => data.push_str(&format!("{}:", packet.chars().count())), - ProtocolVersion::V4 => data.push_str("\x1e"), + ProtocolVersion::V4 => data.push(std::char::from_u32(PACKET_SEPARATOR as u32).unwrap()), } } else if protocol == ProtocolVersion::V3 { data.push_str(&format!("{}:", packet.chars().count())); diff --git a/engineioxide/src/packet.rs b/engineioxide/src/packet.rs index bda99963..5c0122f4 100644 --- a/engineioxide/src/packet.rs +++ b/engineioxide/src/packet.rs @@ -85,7 +85,7 @@ impl TryFrom for Packet { type Error = crate::errors::Error; fn try_from(value: String) -> Result { let mut chars = value.chars(); - let packet_type = chars.next().ok_or(serde_json::Error::custom( + let packet_type = chars.next().ok_or_else(|| serde_json::Error::custom( "Packet type not found in packet string", ))?; let packet_data = chars.as_str(); @@ -112,7 +112,7 @@ impl TryFrom for Packet { '6' => Packet::Noop, 'b' => { if value.starts_with("b4") { - Packet::BinaryV3(general_purpose::STANDARD.decode(&packet_data[1..].as_bytes())?) + Packet::BinaryV3(general_purpose::STANDARD.decode(packet_data[1..].as_bytes())?) } else { Packet::Binary(general_purpose::STANDARD.decode(packet_data.as_bytes())?) } diff --git a/engineioxide/src/payload.rs b/engineioxide/src/payload.rs index b0e2a21c..478a9518 100644 --- a/engineioxide/src/payload.rs +++ b/engineioxide/src/payload.rs @@ -3,8 +3,7 @@ use cfg_if::cfg_if; use crate::protocol::ProtocolVersion; -#[cfg(feature = "v4")] -const PACKET_SEPARATOR: u8 = b'\x1e'; +pub const PACKET_SEPARATOR: u8 = b'\x1e'; /// A payload is a series of encoded packets tied together. /// How packets are tied together depends on the protocol. diff --git a/socketioxide/src/packet.rs b/socketioxide/src/packet.rs index 672a329a..fdc8d1c2 100644 --- a/socketioxide/src/packet.rs +++ b/socketioxide/src/packet.rs @@ -338,7 +338,7 @@ impl TryFrom for Packet { let data = chars.as_str(); let inner = match index { - '0' => PacketData::Connect(deserialize_packet(data)?.unwrap_or(json!({}))), + '0' => PacketData::Connect(deserialize_packet(data)?.unwrap_or_else(|| json!({}))), '1' => PacketData::Disconnect, '2' => { let (event, payload) = deserialize_event_packet(data)?; From 3c0700018fd78ccefd2c97d5d5d9aed6fd31241f Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 09:51:48 +0200 Subject: [PATCH 21/43] Refactor: move protocol enum to service mod --- engineioxide/src/engine.rs | 2 +- engineioxide/src/lib.rs | 1 - engineioxide/src/payload.rs | 4 ++-- engineioxide/src/protocol.rs | 21 --------------------- engineioxide/src/service.rs | 20 +++++++++++++++++++- engineioxide/src/socket.rs | 4 ++-- 6 files changed, 24 insertions(+), 28 deletions(-) delete mode 100644 engineioxide/src/protocol.rs diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index b3b7834c..ae1e9079 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -4,7 +4,7 @@ use std::{ sync::{Arc, RwLock}, }; -use crate::{sid_generator::Sid, protocol::ProtocolVersion, payload::{Payload, PACKET_SEPARATOR}}; +use crate::{sid_generator::Sid, payload::{Payload, PACKET_SEPARATOR}, service::ProtocolVersion}; use crate::{ body::ResponseBody, config::EngineIoConfig, diff --git a/engineioxide/src/lib.rs b/engineioxide/src/lib.rs index 15e7f9a5..dfa01102 100644 --- a/engineioxide/src/lib.rs +++ b/engineioxide/src/lib.rs @@ -10,7 +10,6 @@ pub mod layer; pub mod service; pub mod sid_generator; pub mod socket; -pub mod protocol; mod body; mod engine; diff --git a/engineioxide/src/payload.rs b/engineioxide/src/payload.rs index 478a9518..b41eb143 100644 --- a/engineioxide/src/payload.rs +++ b/engineioxide/src/payload.rs @@ -1,7 +1,7 @@ use std::{io::BufRead, vec}; use cfg_if::cfg_if; -use crate::protocol::ProtocolVersion; +use crate::service::ProtocolVersion; pub const PACKET_SEPARATOR: u8 = b'\x1e'; @@ -124,7 +124,7 @@ impl Iterator for Payload { mod tests { use std::{io::{BufReader, Cursor}, vec}; - use crate::protocol::ProtocolVersion; + use crate::service::ProtocolVersion; use super::{Payload, PACKET_SEPARATOR}; diff --git a/engineioxide/src/protocol.rs b/engineioxide/src/protocol.rs deleted file mode 100644 index 6a22c2ce..00000000 --- a/engineioxide/src/protocol.rs +++ /dev/null @@ -1,21 +0,0 @@ -use std::str::FromStr; - -use crate::errors::Error; - -#[derive(Debug, Clone, PartialEq)] -pub enum ProtocolVersion { - V3 = 3, - V4 = 4, -} - -impl FromStr for ProtocolVersion { - type Err = Error; - - fn from_str(s: &str) -> Result { - match s { - "3" => Ok(ProtocolVersion::V3), - "4" => Ok(ProtocolVersion::V4), - _ => Err(Error::UnsupportedProtocolVersion), - } - } -} diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs index de116309..64ea013a 100644 --- a/engineioxide/src/service.rs +++ b/engineioxide/src/service.rs @@ -8,7 +8,7 @@ use crate::{ }, futures::ResponseFuture, handler::EngineIoHandler, - sid_generator::Sid, protocol::ProtocolVersion, + sid_generator::Sid }; use bytes::Bytes; use cfg_if::cfg_if; @@ -212,6 +212,24 @@ impl FromStr for TransportType { } } +#[derive(Debug, Clone, PartialEq)] +pub enum ProtocolVersion { + V3 = 3, + V4 = 4, +} + +impl FromStr for ProtocolVersion { + type Err = Error; + + fn from_str(s: &str) -> Result { + match s { + "3" => Ok(ProtocolVersion::V3), + "4" => Ok(ProtocolVersion::V4), + _ => Err(Error::UnsupportedProtocolVersion), + } + } +} + /// The request information extracted from the request URI. #[derive(Debug)] struct RequestInfo { diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index 6973d95a..19d9b20e 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -16,9 +16,9 @@ use tracing::debug; use crate::{ config::EngineIoConfig, errors::Error, handler::EngineIoHandler, packet::Packet, - utils::forward_map_chan, SendPacket, + utils::forward_map_chan, SendPacket, service::ProtocolVersion, }; -use crate::{protocol::ProtocolVersion, sid_generator::Sid}; +use crate::sid_generator::Sid; #[derive(Debug, Clone, PartialEq)] pub(crate) enum ConnectionType { From c2cc6643c97c8a4c86bece9b5ad0703a87202417 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 09:54:12 +0200 Subject: [PATCH 22/43] Refactor: derrive clone on protocol enum --- engineioxide/src/engine.rs | 6 +++--- engineioxide/src/service.rs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index ae1e9079..1e4cddde 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -69,7 +69,7 @@ impl EngineIo let sid = generate_sid(); let socket = Socket::new( sid, - protocol.clone(), + protocol, ConnectionType::Http, &self.config, SocketReq::from(req.into_parts().0), @@ -285,7 +285,7 @@ impl EngineIo Some(_) => { debug!("[sid={sid}] websocket connection upgrade"); let mut ws = ws_init().await; - self.ws_upgrade_handshake(protocol.clone(), sid, &mut ws).await?; + self.ws_upgrade_handshake(protocol, sid, &mut ws).await?; (self.get_socket(sid).unwrap(), ws) } } @@ -295,7 +295,7 @@ impl EngineIo let close_fn = Box::new(move |sid: Sid| engine.close_session(sid)); let socket = Socket::new( sid, - protocol.clone(), + protocol, ConnectionType::WebSocket, &self.config, req_data, diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs index 64ea013a..6229d6e0 100644 --- a/engineioxide/src/service.rs +++ b/engineioxide/src/service.rs @@ -212,7 +212,7 @@ impl FromStr for TransportType { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq)] pub enum ProtocolVersion { V3 = 3, V4 = 4, From 9cedee5f7b467112753b01473fbd76a545057afb Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 09:55:24 +0200 Subject: [PATCH 23/43] Refactor: remove unnecessary clone on msg --- engineioxide/src/engine.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 1e4cddde..89e3750e 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -354,7 +354,7 @@ impl EngineIo while let Ok(msg) = rx.try_next().await { let Some(msg) = msg else { continue }; match msg { - Message::Text(msg) => match Packet::try_from(msg.clone())? { + Message::Text(msg) => match Packet::try_from(msg)? { Packet::Close => { debug!("[sid={}] closing session", socket.sid); self.close_session(socket.sid); From 2270748ba89a878c9e2a3c1ba65e3380867eaa23 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 09:59:25 +0200 Subject: [PATCH 24/43] Refactor: add comments --- engineioxide/src/engine.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 89e3750e..5c541cfe 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -89,6 +89,8 @@ impl EngineIo let mut packet: String = Packet::Open(packet).try_into()?; cfg_if! { if #[cfg(feature = "v3")] { + // The V3 protocol requires the packet length to be prepended to the packet. + // It doesn't use a packet separator like the V4 protocol (and up). if protocol == ProtocolVersion::V3 { packet = format!("{}:{}", packet.chars().count(), packet); } @@ -132,6 +134,8 @@ impl EngineIo debug!("sending packet: {:?}", packet); let packet: String = packet.try_into().unwrap(); if !data.is_empty() { + // The V3 protocol requires the packet length to be prepended to the packet. + // The V4 protocol (and up) only requires a packet separator. match protocol { ProtocolVersion::V3 => data.push_str(&format!("{}:", packet.chars().count())), ProtocolVersion::V4 => data.push(std::char::from_u32(PACKET_SEPARATOR as u32).unwrap()), @@ -148,6 +152,7 @@ impl EngineIo let packet: String = packet.try_into().unwrap(); cfg_if! { if #[cfg(feature = "v3")] { + // The V3 protocol specifically requires the packet length to be prepended to the packet. if protocol == ProtocolVersion::V3 { data.push_str(&format!("{}:", packet.chars().count())); } From 62ab7b42a4b8648fd08c655e47259cf234503b94 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 10:07:24 +0200 Subject: [PATCH 25/43] Refactor: remove unneeded cfg_if statements --- engineioxide/src/payload.rs | 36 +++++++++--------- engineioxide/src/socket.rs | 76 ++++++++++++++++++------------------- 2 files changed, 54 insertions(+), 58 deletions(-) diff --git a/engineioxide/src/payload.rs b/engineioxide/src/payload.rs index b41eb143..3c86412c 100644 --- a/engineioxide/src/payload.rs +++ b/engineioxide/src/payload.rs @@ -1,5 +1,4 @@ use std::{io::BufRead, vec}; -use cfg_if::cfg_if; use crate::service::ProtocolVersion; @@ -96,28 +95,27 @@ impl Payload { impl Iterator for Payload { type Item = Item; - cfg_if! { - if #[cfg(all(feature = "v3", feature = "v4"))] { - fn next(&mut self) -> Option { - match self.protocol { - ProtocolVersion::V3 => { - self.next_v3() - }, - ProtocolVersion::V4 => { - self.next_v4() - }, - } - } - } else if #[cfg(feature = "v3")] { - fn next(&mut self) -> Option { + #[cfg(all(feature = "v3", feature = "v4"))] + fn next(&mut self) -> Option { + match self.protocol { + ProtocolVersion::V3 => { self.next_v3() - } - } else { - fn next(&mut self) -> Option { + }, + ProtocolVersion::V4 => { self.next_v4() - } + }, } } + + #[cfg(feature = "v3")] + fn next(&mut self) -> Option { + self.next_v3() + } + + #[cfg(feature = "v4")] + fn next(&mut self) -> Option { + self.next_v4() + } } #[cfg(test)] diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index 19d9b20e..57a6e887 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -6,7 +6,6 @@ use std::{ time::Duration, }; -use cfg_if::cfg_if; use http::{request::Parts, Uri}; use tokio::{ sync::{mpsc, mpsc::Receiver, Mutex}, @@ -181,50 +180,49 @@ where .replace(handle); } - cfg_if! { - if #[cfg(all(feature = "v3", feature = "v4"))] { - /// Heartbeat is sent every `interval` milliseconds and the client or server (depending on the protocol) is expected to respond within `timeout` milliseconds. - /// - /// If the client or server does not respond within the timeout, the connection is closed. - async fn heartbeat_job( - &self, - interval: Duration, - timeout: Duration, - ) -> Result<(), Error> { - match self.protocol { - ProtocolVersion::V3 => { - self.heartbeat_job_v3(timeout).await - } - ProtocolVersion::V4 => { - self.heartbeat_job_v4(interval, timeout).await - } - } - } - } else if #[cfg(feature = "v3")] { - /// Heartbeat is sent every `interval` milliseconds by the client and the server is expected to respond within `timeout` milliseconds. - /// - /// If the client or server does not respond within the timeout, the connection is closed. - async fn heartbeat_job( - &self, - interval: Duration, - timeout: Duration, - ) -> Result<(), Error> { - self.heartbeat_job_v3(timeout) + /// Heartbeat is sent every `interval` milliseconds and the client or server (depending on the protocol) is expected to respond within `timeout` milliseconds. + /// + /// If the client or server does not respond within the timeout, the connection is closed. + #[cfg(all(feature = "v3", feature = "v4"))] + async fn heartbeat_job( + &self, + interval: Duration, + timeout: Duration, + ) -> Result<(), Error> { + match self.protocol { + ProtocolVersion::V3 => { + self.heartbeat_job_v3(timeout).await } - } else { - /// Heartbeat is sent every `interval` milliseconds and the client is expected to respond within `timeout` milliseconds. - /// - /// If the client does not respond within the timeout, the connection is closed. - async fn heartbeat_job( - &self, - interval: Duration, - timeout: Duration, - ) -> Result<(), Error> { + ProtocolVersion::V4 => { self.heartbeat_job_v4(interval, timeout).await } } } + /// Heartbeat is sent every `interval` milliseconds by the client and the server is expected to respond within `timeout` milliseconds. + /// + /// If the client or server does not respond within the timeout, the connection is closed. + #[cfg(feature = "v3")] + async fn heartbeat_job( + &self, + interval: Duration, + timeout: Duration, + ) -> Result<(), Error> { + self.heartbeat_job_v3(timeout) + } + + /// Heartbeat is sent every `interval` milliseconds and the client is expected to respond within `timeout` milliseconds. + /// + /// If the client does not respond within the timeout, the connection is closed. + #[cfg(feature = "v4")] + async fn heartbeat_job( + &self, + interval: Duration, + timeout: Duration, + ) -> Result<(), Error> { + self.heartbeat_job_v4(interval, timeout).await + } + /// Heartbeat is sent every `interval` milliseconds and the client is expected to respond within `timeout` milliseconds. /// /// If the client does not respond within the timeout, the connection is closed. From 59cf42b6b5cf80b93e1c90e69d2b1a772b24794f Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 10:39:28 +0200 Subject: [PATCH 26/43] Refactor: return proper error types --- engineioxide/src/errors.rs | 3 +++ engineioxide/src/payload.rs | 34 +++++++++++++++++----------------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/engineioxide/src/errors.rs b/engineioxide/src/errors.rs index 5e307019..1adf9263 100644 --- a/engineioxide/src/errors.rs +++ b/engineioxide/src/errors.rs @@ -48,6 +48,9 @@ pub enum Error { TransportMismatch, #[error("unsupported protocol version")] UnsupportedProtocolVersion, + + #[error("Invalid packet length")] + InvalidPacketLength, } /// Convert an error into an http response diff --git a/engineioxide/src/payload.rs b/engineioxide/src/payload.rs index 3c86412c..ae446fa2 100644 --- a/engineioxide/src/payload.rs +++ b/engineioxide/src/payload.rs @@ -1,6 +1,6 @@ use std::{io::BufRead, vec}; -use crate::service::ProtocolVersion; +use crate::{service::ProtocolVersion, errors::Error}; pub const PACKET_SEPARATOR: u8 = b'\x1e'; @@ -13,7 +13,7 @@ pub struct Payload { protocol: ProtocolVersion, } -type Item = Result; +type Item = Result; impl Payload { pub fn new(protocol: ProtocolVersion, data: R) -> Self { @@ -41,10 +41,10 @@ impl Payload { if let Ok(l) = s.parse::() { l } else { - return Some(Err("Invalid packet length".into())); + return Some(Err(Error::InvalidPacketLength)); } }, - Err(_) => return Some(Err("Invalid packet length".into())), + Err(_) => return Some(Err(Error::InvalidPacketLength)), }; self.buffer.clear(); @@ -54,16 +54,16 @@ impl Payload { Ok(_) => { match String::from_utf8(self.buffer.clone()) { Ok(s) => Some(Ok(s)), - Err(_) => Some(Err("Invalid packet data".into())), + Err(e) => Some(Err(Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))), } }, - Err(err) => Some(Err(err.to_string())), + Err(e) => Some(Err(Error::Io(e))), } } else { None } } - Err(err) => Some(Err(err.to_string())), + Err(e) => Some(Err(Error::Io(e))), } } @@ -81,13 +81,13 @@ impl Payload { match String::from_utf8( self.buffer.clone()) { Ok(s) => Some(Ok(s)), - Err(_) => Some(Err("Packet is not a valid UTF-8 string".into())), + Err(e) => Some(Err(Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))), } } else { None } } - Err(err) => Some(Err(err.to_string())), + Err(e) => Some(Err(Error::Io(e))), } } } @@ -135,10 +135,10 @@ mod tests { ])); let mut payload = Payload::new(ProtocolVersion::V4, data); - assert_eq!(payload.next(), Some(Ok("foo".into()))); - assert_eq!(payload.next(), Some(Ok("fo".into()))); - assert_eq!(payload.next(), Some(Ok("f".into()))); - assert_eq!(payload.next(), None); + assert_eq!(payload.next().unwrap().unwrap(), "foo"); + assert_eq!(payload.next().unwrap().unwrap(), "fo"); + assert_eq!(payload.next().unwrap().unwrap(), "f"); + assert_eq!(payload.next().is_none(), true); Ok(()) } @@ -152,10 +152,10 @@ mod tests { ])); let mut payload = Payload::new(ProtocolVersion::V3, data); - assert_eq!(payload.next(), Some(Ok("foo".into()))); - assert_eq!(payload.next(), Some(Ok("fo".into()))); - assert_eq!(payload.next(), Some(Ok("f".into()))); - assert_eq!(payload.next(), None); + assert_eq!(payload.next().unwrap().unwrap(), "foo"); + assert_eq!(payload.next().unwrap().unwrap(), "fo"); + assert_eq!(payload.next().unwrap().unwrap(), "f"); + assert_eq!(payload.next().is_none(), true); Ok(()) } From 79c6f41fafe1e32afd5ea513cd2a5fb8eaebeb06 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 10:41:20 +0200 Subject: [PATCH 27/43] Fix: fix duplicate definitions error --- engineioxide/src/payload.rs | 2 ++ engineioxide/src/socket.rs | 2 ++ 2 files changed, 4 insertions(+) diff --git a/engineioxide/src/payload.rs b/engineioxide/src/payload.rs index ae446fa2..a13fc7a2 100644 --- a/engineioxide/src/payload.rs +++ b/engineioxide/src/payload.rs @@ -108,11 +108,13 @@ impl Iterator for Payload { } #[cfg(feature = "v3")] + #[cfg(not(feature = "v4"))] fn next(&mut self) -> Option { self.next_v3() } #[cfg(feature = "v4")] + #[cfg(not(feature = "v3"))] fn next(&mut self) -> Option { self.next_v4() } diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index 57a6e887..9f007b9b 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -203,6 +203,7 @@ where /// /// If the client or server does not respond within the timeout, the connection is closed. #[cfg(feature = "v3")] + #[cfg(not(feature = "v4"))] async fn heartbeat_job( &self, interval: Duration, @@ -215,6 +216,7 @@ where /// /// If the client does not respond within the timeout, the connection is closed. #[cfg(feature = "v4")] + #[cfg(not(feature = "v3"))] async fn heartbeat_job( &self, interval: Duration, From b6d689d33e7e4df5f64d59b94bd291f131bdef3b Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 11:04:08 +0200 Subject: [PATCH 28/43] Refactor: optimize buffer consumption --- engineioxide/src/payload.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/engineioxide/src/payload.rs b/engineioxide/src/payload.rs index a13fc7a2..1fc72676 100644 --- a/engineioxide/src/payload.rs +++ b/engineioxide/src/payload.rs @@ -26,8 +26,6 @@ impl Payload { #[cfg(feature = "v3")] fn next_v3(&mut self) -> Option { - self.buffer.clear(); - match self.reader.read_until(b':', &mut self.buffer) { Ok(bytes_read) => { if bytes_read > 0 { @@ -36,7 +34,8 @@ impl Payload { self.buffer.pop(); } - let length = match String::from_utf8(self.buffer.clone()) { + let buffer = std::mem::take(&mut self.buffer); + let length = match String::from_utf8(buffer) { Ok(s) => { if let Ok(l) = s.parse::() { l @@ -47,12 +46,12 @@ impl Payload { Err(_) => return Some(Err(Error::InvalidPacketLength)), }; - self.buffer.clear(); self.buffer.resize(length, 0); match self.reader.read_exact(&mut self.buffer) { Ok(_) => { - match String::from_utf8(self.buffer.clone()) { + let buffer = std::mem::take(&mut self.buffer); + match String::from_utf8(buffer) { Ok(s) => Some(Ok(s)), Err(e) => Some(Err(Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))), } @@ -69,8 +68,6 @@ impl Payload { #[cfg(feature = "v4")] fn next_v4(&mut self) -> Option { - self.buffer.clear(); - match self.reader.read_until(PACKET_SEPARATOR, &mut self.buffer) { Ok(bytes_read) => { if bytes_read > 0 { @@ -79,7 +76,8 @@ impl Payload { self.buffer.pop(); } - match String::from_utf8( self.buffer.clone()) { + let buffer = std::mem::take(&mut self.buffer); + match String::from_utf8(buffer) { Ok(s) => Some(Ok(s)), Err(e) => Some(Err(Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))), } From 72d368ff8f64a5c13183921e0f6f484974c540e4 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 11:11:55 +0200 Subject: [PATCH 29/43] Refactor: rewrite to remove `unused_mut` override --- engineioxide/src/engine.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 5c541cfe..25ec3c86 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -85,17 +85,21 @@ impl EngineIo self.handler.on_connect(&socket); let packet = OpenPacket::new(TransportType::Polling, sid, &self.config); - #[allow(unused_mut)] - let mut packet: String = Packet::Open(packet).try_into()?; - cfg_if! { - if #[cfg(feature = "v3")] { + let packet: String = Packet::Open(packet).try_into()?; + let packet = { + #[cfg(feature = "v3")] + { + let mut packet = packet; // The V3 protocol requires the packet length to be prepended to the packet. // It doesn't use a packet separator like the V4 protocol (and up). if protocol == ProtocolVersion::V3 { packet = format!("{}:{}", packet.chars().count(), packet); } + packet } - } + #[cfg(not(feature = "v3"))] + packet + }; http_response(StatusCode::OK, packet).map_err(Error::Http) } From 104b82f11fec6d36fff3964975fe87eff5703b43 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 11:16:35 +0200 Subject: [PATCH 30/43] Test: check protocol version in unit tests --- engineioxide/src/service.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs index 6229d6e0..15118509 100644 --- a/engineioxide/src/service.rs +++ b/engineioxide/src/service.rs @@ -317,6 +317,7 @@ mod tests { let info = RequestInfo::parse(&req).unwrap(); assert_eq!(info.sid, None); assert_eq!(info.transport, TransportType::Polling); + assert_eq!(info.protocol, ProtocolVersion::V4); assert_eq!(info.method, Method::GET); } @@ -326,28 +327,31 @@ mod tests { let info = RequestInfo::parse(&req).unwrap(); assert_eq!(info.sid, None); assert_eq!(info.transport, TransportType::Websocket); + assert_eq!(info.protocol, ProtocolVersion::V4); assert_eq!(info.method, Method::GET); } #[test] fn request_info_polling_with_sid() { let req = build_request( - "http://localhost:3000/socket.io/?EIO=4&transport=polling&sid=AAAAAAAAAHs", + "http://localhost:3000/socket.io/?EIO=3&transport=polling&sid=AAAAAAAAAHs", ); let info = RequestInfo::parse(&req).unwrap(); assert_eq!(info.sid, Some(123i64.into())); assert_eq!(info.transport, TransportType::Polling); + assert_eq!(info.protocol, ProtocolVersion::V3); assert_eq!(info.method, Method::GET); } #[test] fn request_info_websocket_with_sid() { let req = build_request( - "http://localhost:3000/socket.io/?EIO=4&transport=websocket&sid=AAAAAAAAAHs", + "http://localhost:3000/socket.io/?EIO=3&transport=websocket&sid=AAAAAAAAAHs", ); let info = RequestInfo::parse(&req).unwrap(); assert_eq!(info.sid, Some(123i64.into())); assert_eq!(info.transport, TransportType::Websocket); + assert_eq!(info.protocol, ProtocolVersion::V3); assert_eq!(info.method, Method::GET); } #[test] From 811a78067397c8c9ab38787acc4b766314c9998d Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 11:34:59 +0200 Subject: [PATCH 31/43] Chore: upate engineio CI config --- .github/workflows/engineio-ci.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/engineio-ci.yml b/.github/workflows/engineio-ci.yml index f6d1fb20..499098fc 100644 --- a/.github/workflows/engineio-ci.yml +++ b/.github/workflows/engineio-ci.yml @@ -21,10 +21,9 @@ jobs: toolchain: stable - uses: actions/checkout@v3 with: - # TODO: replace with official repo once https://github.com/socketio/engine.io-protocol/pull/45 has been merged - repository: sleeyax/engine.io-protocol + repository: socketio/engine.io-protocol path: engine.io-protocol - ref: v3-testsuite + ref: v3 - uses: actions/setup-node@v3 with: node-version: 16 From 2b430c992be572a686700bc360959c03dea4faaa Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 11:42:29 +0200 Subject: [PATCH 32/43] Refactor: improve binary packet matching --- engineioxide/src/packet.rs | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/engineioxide/src/packet.rs b/engineioxide/src/packet.rs index 5c0122f4..9b07ba89 100644 --- a/engineioxide/src/packet.rs +++ b/engineioxide/src/packet.rs @@ -110,13 +110,8 @@ impl TryFrom for Packet { '4' => Packet::Message(packet_data.to_string()), '5' => Packet::Upgrade, '6' => Packet::Noop, - 'b' => { - if value.starts_with("b4") { - Packet::BinaryV3(general_purpose::STANDARD.decode(packet_data[1..].as_bytes())?) - } else { - Packet::Binary(general_purpose::STANDARD.decode(packet_data.as_bytes())?) - } - }, + 'b' if value.starts_with("b4") => Packet::BinaryV3(general_purpose::STANDARD.decode(packet_data[1..].as_bytes())?), + 'b' => Packet::Binary(general_purpose::STANDARD.decode(packet_data.as_bytes())?), c => Err(serde_json::Error::custom( "Invalid packet type ".to_string() + &c.to_string(), ))?, From 5d19239184b70d16a9e525ba72e2b5cbe019810c Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 11:46:55 +0200 Subject: [PATCH 33/43] Refactor: simplify error mapping --- engineioxide/src/payload.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/engineioxide/src/payload.rs b/engineioxide/src/payload.rs index 1fc72676..2829ae5d 100644 --- a/engineioxide/src/payload.rs +++ b/engineioxide/src/payload.rs @@ -51,10 +51,7 @@ impl Payload { match self.reader.read_exact(&mut self.buffer) { Ok(_) => { let buffer = std::mem::take(&mut self.buffer); - match String::from_utf8(buffer) { - Ok(s) => Some(Ok(s)), - Err(e) => Some(Err(Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))), - } + Some(String::from_utf8(buffer).map_err(Into::into)) }, Err(e) => Some(Err(Error::Io(e))), } @@ -77,10 +74,7 @@ impl Payload { } let buffer = std::mem::take(&mut self.buffer); - match String::from_utf8(buffer) { - Ok(s) => Some(Ok(s)), - Err(e) => Some(Err(Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))), - } + Some(String::from_utf8(buffer).map_err(Into::into)) } else { None } From d3bdd83efdddca68f869539c6f2bde39665523ac Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 11:50:18 +0200 Subject: [PATCH 34/43] Refactor: combine packet matching --- engineioxide/src/engine.rs | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 25ec3c86..d6abcb16 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -203,39 +203,35 @@ impl EngineIo ) })?; - let packet = match Packet::try_from(raw_packet) { - Ok(p) => p, - Err(e) => { - debug!("[sid={sid}] error parsing packet: {:?}", e); - self.close_session(sid); - return Err(e); - } - }; - - match packet { - Packet::Close => { + match Packet::try_from(raw_packet) { + Ok(Packet::Close) => { debug!("[sid={sid}] closing session"); socket.send(Packet::Noop)?; self.close_session(sid); break; } - Packet::Pong | Packet::Ping => { + Ok(Packet::Pong) | Ok(Packet::Ping) => { socket .pong_tx .try_send(()) .map_err(|_| Error::HeartbeatTimeout) }, - Packet::Message(msg) => { + Ok(Packet::Message(msg)) => { self.handler.on_message(msg, &socket); Ok(()) } - Packet::Binary(bin) | Packet::BinaryV3(bin) => { + Ok(Packet::Binary(bin)) | Ok(Packet::BinaryV3(bin)) => { self.handler.on_binary(bin, &socket); Ok(()) } - p => { + Ok(p) => { debug!("[sid={sid}] bad packet received: {:?}", &p); Err(Error::BadPacket(p)) + }, + Err(e) => { + debug!("[sid={sid}] error parsing packet: {:?}", e); + self.close_session(sid); + return Err(e); } }?; } From 857c0a59408c5a455099936c8c2ac3d427d7e28f Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 14:49:10 +0200 Subject: [PATCH 35/43] Refactor: simplify v3 payload parsing --- engineioxide/src/payload.rs | 46 ++++++++++++------------------------- 1 file changed, 15 insertions(+), 31 deletions(-) diff --git a/engineioxide/src/payload.rs b/engineioxide/src/payload.rs index 2829ae5d..27596c35 100644 --- a/engineioxide/src/payload.rs +++ b/engineioxide/src/payload.rs @@ -27,38 +27,22 @@ impl Payload { #[cfg(feature = "v3")] fn next_v3(&mut self) -> Option { match self.reader.read_until(b':', &mut self.buffer) { - Ok(bytes_read) => { - if bytes_read > 0 { - // remove trailing separator - if self.buffer.ends_with(&[b':']) { - self.buffer.pop(); - } - - let buffer = std::mem::take(&mut self.buffer); - let length = match String::from_utf8(buffer) { - Ok(s) => { - if let Ok(l) = s.parse::() { - l - } else { - return Some(Err(Error::InvalidPacketLength)); - } - }, - Err(_) => return Some(Err(Error::InvalidPacketLength)), - }; - - self.buffer.resize(length, 0); - - match self.reader.read_exact(&mut self.buffer) { - Ok(_) => { - let buffer = std::mem::take(&mut self.buffer); - Some(String::from_utf8(buffer).map_err(Into::into)) - }, - Err(e) => Some(Err(Error::Io(e))), - } - } else { - None + Ok(bytes_read) => (bytes_read > 0).then(|| { + if self.buffer.ends_with(&[b':']) { + self.buffer.pop(); } - } + + let buffer = std::mem::take(&mut self.buffer); + let length = String::from_utf8(buffer) + .map_err(|_| Error::InvalidPacketLength) + .and_then(|s| s.parse::().map_err(|_| Error::InvalidPacketLength))?; + + self.buffer.resize(length, 0); + self.reader.read_exact(&mut self.buffer)?; + + let buffer = std::mem::take(&mut self.buffer); + String::from_utf8(buffer).map_err(Into::into) + }), Err(e) => Some(Err(Error::Io(e))), } } From 711e8e32c38147356e6aad56c441587541e2ca07 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 15:02:49 +0200 Subject: [PATCH 36/43] Refactor: rename pong rx,tx to heartbeat rx,tx This makes the name agnositc from the protocol version that's being used. --- engineioxide/src/engine.rs | 4 ++-- engineioxide/src/socket.rs | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index d6abcb16..7d0f90f4 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -212,7 +212,7 @@ impl EngineIo } Ok(Packet::Pong) | Ok(Packet::Ping) => { socket - .pong_tx + .heartbeat_tx .try_send(()) .map_err(|_| Error::HeartbeatTimeout) }, @@ -367,7 +367,7 @@ impl EngineIo } Packet::Pong | Packet::Ping => { socket - .pong_tx + .heartbeat_tx .try_send(()) .map_err(|_| Error::HeartbeatTimeout) }, diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index 9f007b9b..7af9b25b 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -89,12 +89,12 @@ where internal_tx: mpsc::Sender, pub tx: mpsc::Sender, - /// Internal channel to receive Pong [`Packets`](Packet) in the heartbeat job + /// Internal channel to receive Pong [`Packets`](Packet) (v4 protocol) or Ping (v3 protocol) in the heartbeat job /// which is running in a separate task - pong_rx: Mutex>, - /// Channel to send Ping [`Packets`](Packet) from the connexion to the heartbeat job + heartbeat_rx: Mutex>, + /// Channel to send Ping [`Packets`](Packet) (v4 protocol) or Ping (v3 protocol) from the connexion to the heartbeat job /// which is running in a separate task - pub(crate) pong_tx: mpsc::Sender<()>, + pub(crate) heartbeat_tx: mpsc::Sender<()>, /// Handle to the heartbeat job so that it can be aborted when the socket is closed heartbeat_handle: Mutex>>, @@ -134,8 +134,8 @@ where internal_tx, tx, - pong_rx: Mutex::new(pong_rx), - pong_tx, + heartbeat_rx: Mutex::new(pong_rx), + heartbeat_tx: pong_tx, heartbeat_handle: Mutex::new(None), close_fn, @@ -235,7 +235,7 @@ where timeout: Duration, ) -> Result<(), Error> { let mut pong_rx = self - .pong_rx + .heartbeat_rx .try_lock() .expect("Pong rx should be locked only once"); @@ -360,8 +360,8 @@ impl Socket { internal_tx, tx, - pong_rx: Mutex::new(pong_rx), - pong_tx, + heartbeat_rx: Mutex::new(pong_rx), + heartbeat_tx: pong_tx, heartbeat_handle: Mutex::new(None), close_fn, From 7eee19a2833def8f2dd13be4b8464ff42965e8a5 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 15:24:46 +0200 Subject: [PATCH 37/43] Fix: --- engineioxide/src/lib.rs | 7 +++++++ engineioxide/src/service.rs | 14 +------------- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/engineioxide/src/lib.rs b/engineioxide/src/lib.rs index dfa01102..dc69be47 100644 --- a/engineioxide/src/lib.rs +++ b/engineioxide/src/lib.rs @@ -2,6 +2,13 @@ pub use async_trait::async_trait; /// A Packet type to use when sending data to the client pub use packet::SendPacket; +use cfg_if::cfg_if; + +cfg_if! { + if #[cfg(not(any(feature = "v3", feature = "v4")))] { + compile_error!("At least one protocol version must be enabled"); + } +} pub mod config; pub mod errors; diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs index 15118509..8b7cc4d4 100644 --- a/engineioxide/src/service.rs +++ b/engineioxide/src/service.rs @@ -256,22 +256,10 @@ impl RequestInfo { .and_then(|t| t.parse())?; cfg_if! { - if #[cfg(all(feature = "v3", feature = "v4"))] { + if #[cfg(any(feature = "v3", feature = "v4"))] { if protocol != ProtocolVersion::V3 && protocol != ProtocolVersion::V4 { return Err(Error::UnsupportedProtocolVersion); } - } - else if #[cfg(feature = "v4")] { - if protocol != ProtocolVersion::V4 { - return Err(Error::UnsupportedProtocolVersion); - } - } - else if #[cfg(feature = "v3")] { - if protocol != ProtocolVersion::V3 { - return Err(Error::UnsupportedProtocolVersion); - } - } else { - compile_error!("At least one protocol version must be enabled"); } } From 728891d83a5f8f68a73cd4de46a046c826ef019c Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 15:30:17 +0200 Subject: [PATCH 38/43] Fix: rename remainging pong rx,tx references I was a bit too fast and forgot to rename a few references, causing compiler errors. This commit fixes this. --- engineioxide/src/socket.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index 7af9b25b..042053fd 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -121,7 +121,7 @@ where ) -> Self { let (internal_tx, internal_rx) = mpsc::channel(config.max_buffer_size); let (tx, rx) = mpsc::channel(config.max_buffer_size); - let (pong_tx, pong_rx) = mpsc::channel(1); + let (heartbeat_tx, heartbeat_rx) = mpsc::channel(1); tokio::spawn(forward_map_chan(rx, internal_tx.clone(), SendPacket::into)); @@ -134,8 +134,8 @@ where internal_tx, tx, - heartbeat_rx: Mutex::new(pong_rx), - heartbeat_tx: pong_tx, + heartbeat_rx: Mutex::new(heartbeat_rx), + heartbeat_tx, heartbeat_handle: Mutex::new(None), close_fn, @@ -234,7 +234,7 @@ where interval: Duration, timeout: Duration, ) -> Result<(), Error> { - let mut pong_rx = self + let mut heartbeat_rx = self .heartbeat_rx .try_lock() .expect("Pong rx should be locked only once"); @@ -252,12 +252,12 @@ where loop { // Some clients send the pong packet in first. If that happens, we should consume it. - pong_rx.try_recv().ok(); + heartbeat_rx.try_recv().ok(); self.internal_tx .try_send(Packet::Ping) .map_err(|_| Error::HeartbeatTimeout)?; - tokio::time::timeout(timeout, pong_rx.recv()) + tokio::time::timeout(timeout, heartbeat_rx.recv()) .await .map_err(|_| Error::HeartbeatTimeout)? .ok_or(Error::HeartbeatTimeout)?; @@ -270,15 +270,15 @@ where &self, timeout: Duration, ) -> Result<(), Error> { - let mut pong_rx = self - .pong_rx + let mut heartbeat_rx = self + .heartbeat_rx .try_lock() .expect("Pong rx should be locked only once"); debug!("[sid={}] heartbeat receiver routine started", self.sid); loop { - tokio::time::timeout(timeout, pong_rx.recv()) + tokio::time::timeout(timeout, heartbeat_rx.recv()) .await .map_err(|_| Error::HeartbeatTimeout)? .ok_or(Error::HeartbeatTimeout)?; @@ -347,7 +347,7 @@ impl Socket { pub fn new_dummy(sid: Sid, close_fn: Box) -> Socket { let (internal_tx, internal_rx) = mpsc::channel(200); let (tx, rx) = mpsc::channel(200); - let (pong_tx, pong_rx) = mpsc::channel(1); + let (heartbeat_tx, heartbeat_rx) = mpsc::channel(1); tokio::spawn(forward_map_chan(rx, internal_tx.clone(), SendPacket::into)); @@ -360,8 +360,8 @@ impl Socket { internal_tx, tx, - heartbeat_rx: Mutex::new(pong_rx), - heartbeat_tx: pong_tx, + heartbeat_rx: Mutex::new(heartbeat_rx), + heartbeat_tx, heartbeat_handle: Mutex::new(None), close_fn, From e05212379d423186562786bce3c9e939b1ddcd9d Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 15:39:30 +0200 Subject: [PATCH 39/43] Refactor: move protcol validation to FromStr implementation --- engineioxide/src/service.rs | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs index 8b7cc4d4..68131de1 100644 --- a/engineioxide/src/service.rs +++ b/engineioxide/src/service.rs @@ -11,7 +11,6 @@ use crate::{ sid_generator::Sid }; use bytes::Bytes; -use cfg_if::cfg_if; use futures::future::{ready, Ready}; use http::{Method, Request}; use http_body::{Body, Empty}; @@ -221,6 +220,7 @@ pub enum ProtocolVersion { impl FromStr for ProtocolVersion { type Err = Error; + #[cfg(all(feature = "v3", feature = "v4"))] fn from_str(s: &str) -> Result { match s { "3" => Ok(ProtocolVersion::V3), @@ -228,6 +228,24 @@ impl FromStr for ProtocolVersion { _ => Err(Error::UnsupportedProtocolVersion), } } + + #[cfg(feature = "v4")] + #[cfg(not(feature = "v3"))] + fn from_str(s: &str) -> Result { + match s { + "4" => Ok(ProtocolVersion::V4), + _ => Err(Error::UnsupportedProtocolVersion), + } + } + + #[cfg(feature = "v3")] + #[cfg(not(feature = "v4"))] + fn from_str(s: &str) -> Result { + match s { + "3" => Ok(ProtocolVersion::V3), + _ => Err(Error::UnsupportedProtocolVersion), + } + } } /// The request information extracted from the request URI. @@ -255,14 +273,6 @@ impl RequestInfo { .ok_or(UnknownTransport) .and_then(|t| t.parse())?; - cfg_if! { - if #[cfg(any(feature = "v3", feature = "v4"))] { - if protocol != ProtocolVersion::V3 && protocol != ProtocolVersion::V4 { - return Err(Error::UnsupportedProtocolVersion); - } - } - } - let sid = query .split('&') .find(|s| s.starts_with("sid=")) From 38483b7b8a5f937d5a44b54c1a4279d9a26f2e3c Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 16:01:25 +0200 Subject: [PATCH 40/43] Refactor: parse length to str rarher than String --- engineioxide/src/payload.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engineioxide/src/payload.rs b/engineioxide/src/payload.rs index 27596c35..0ca3139d 100644 --- a/engineioxide/src/payload.rs +++ b/engineioxide/src/payload.rs @@ -33,7 +33,7 @@ impl Payload { } let buffer = std::mem::take(&mut self.buffer); - let length = String::from_utf8(buffer) + let length = std::str::from_utf8(&buffer) .map_err(|_| Error::InvalidPacketLength) .and_then(|s| s.parse::().map_err(|_| Error::InvalidPacketLength))?; From a19176ee343ead4114db4a36c121be77077e9cb7 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 16:18:51 +0200 Subject: [PATCH 41/43] Refactor: remove unneeded cfg-if dependency --- Cargo.lock | 1 - engineioxide/Cargo.toml | 1 - engineioxide/src/engine.rs | 35 ++++++++++++++--------------------- engineioxide/src/lib.rs | 8 ++------ 4 files changed, 16 insertions(+), 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0c49a003..c2da666c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -358,7 +358,6 @@ dependencies = [ "base64 0.21.0", "base64id", "bytes", - "cfg-if", "criterion", "futures", "http", diff --git a/engineioxide/Cargo.toml b/engineioxide/Cargo.toml index f26b400c..123b2c80 100644 --- a/engineioxide/Cargo.toml +++ b/engineioxide/Cargo.toml @@ -30,7 +30,6 @@ tower = "0.4.13" tracing = "0.1.37" rand = "0.8.5" base64id = { version = "0.3.1", features = ["std", "rand", "serde"] } -cfg-if = "1.0.0" [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] } diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 7d0f90f4..4b4d7357 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -17,7 +17,6 @@ use crate::{ socket::{ConnectionType, Socket, SocketReq}, }; use bytes::Buf; -use cfg_if::cfg_if; use futures::{stream::SplitStream, SinkExt, StreamExt, TryStreamExt}; use http::{Request, Response, StatusCode}; use hyper::upgrade::Upgraded; @@ -154,12 +153,10 @@ impl EngineIo if data.is_empty() { let packet = rx.recv().await.ok_or(Error::Aborted)?; let packet: String = packet.try_into().unwrap(); - cfg_if! { - if #[cfg(feature = "v3")] { - // The V3 protocol specifically requires the packet length to be prepended to the packet. - if protocol == ProtocolVersion::V3 { - data.push_str(&format!("{}:", packet.chars().count())); - } + #[cfg(feature = "v3")] { + // The V3 protocol specifically requires the packet length to be prepended to the packet. + if protocol == ProtocolVersion::V3 { + data.push_str(&format!("{}:", packet.chars().count())); } } data.push_str(&packet); @@ -419,12 +416,10 @@ impl EngineIo ) -> Result<(), Error> { let socket = self.get_socket(sid).unwrap(); - cfg_if! { - if #[cfg(feature = "v4")] { - // send a NOOP packet to any pending polling request so it closes gracefully' - if protocol == ProtocolVersion::V4 { - socket.send(Packet::Noop)?; - } + #[cfg(feature = "v4")] { + // send a NOOP packet to any pending polling request so it closes gracefully' + if protocol == ProtocolVersion::V4 { + socket.send(Packet::Noop)?; } } @@ -442,14 +437,12 @@ impl EngineIo p => Err(Error::BadPacket(p))?, }; - cfg_if! { - if #[cfg(feature = "v3")] { - // send a NOOP packet to any pending polling request so it closes gracefully - // V3 protocol introduce _paused_ polling transport which require to close - // the polling request **after** the ping/pong handshake - if protocol == ProtocolVersion::V3 { - socket.send(Packet::Noop)?; - } + #[cfg(feature = "v3")] { + // send a NOOP packet to any pending polling request so it closes gracefully + // V3 protocol introduce _paused_ polling transport which require to close + // the polling request **after** the ping/pong handshake + if protocol == ProtocolVersion::V3 { + socket.send(Packet::Noop)?; } } diff --git a/engineioxide/src/lib.rs b/engineioxide/src/lib.rs index dc69be47..f68a6c4d 100644 --- a/engineioxide/src/lib.rs +++ b/engineioxide/src/lib.rs @@ -2,13 +2,9 @@ pub use async_trait::async_trait; /// A Packet type to use when sending data to the client pub use packet::SendPacket; -use cfg_if::cfg_if; -cfg_if! { - if #[cfg(not(any(feature = "v3", feature = "v4")))] { - compile_error!("At least one protocol version must be enabled"); - } -} +#[cfg(not(any(feature = "v3", feature = "v4")))] +compile_error!("At least one protocol version must be enabled"); pub mod config; pub mod errors; From a9c3fb657cb6c3d9188a6791656599e7ed1500dd Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 16:23:47 +0200 Subject: [PATCH 42/43] Chore: Add TODO comments for future improvements --- engineioxide/src/payload.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/engineioxide/src/payload.rs b/engineioxide/src/payload.rs index 0ca3139d..dcb2f80c 100644 --- a/engineioxide/src/payload.rs +++ b/engineioxide/src/payload.rs @@ -13,7 +13,7 @@ pub struct Payload { protocol: ProtocolVersion, } -type Item = Result; +type Item = Result; // TODO: refactor to return Result<&str, Error> instead (see TODO's below) impl Payload { pub fn new(protocol: ProtocolVersion, data: R) -> Self { @@ -41,7 +41,7 @@ impl Payload { self.reader.read_exact(&mut self.buffer)?; let buffer = std::mem::take(&mut self.buffer); - String::from_utf8(buffer).map_err(Into::into) + String::from_utf8(buffer).map_err(Into::into) // TODO: replace 'String::from_utf8' with 'std::str::from_utf8' }), Err(e) => Some(Err(Error::Io(e))), } @@ -58,7 +58,7 @@ impl Payload { } let buffer = std::mem::take(&mut self.buffer); - Some(String::from_utf8(buffer).map_err(Into::into)) + Some(String::from_utf8(buffer).map_err(Into::into)) // TODO: replace 'String::from_utf8' with 'std::str::from_utf8' } else { None } From 0e60ced4381dcae6d577b2eb89012dab313ee29b Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Sun, 2 Jul 2023 16:33:47 +0200 Subject: [PATCH 43/43] Chore: apply code formatting --- engineioxide/src/engine.rs | 67 +++++++++++++++++++------------------ engineioxide/src/layer.rs | 14 +++----- engineioxide/src/lib.rs | 2 +- engineioxide/src/packet.rs | 12 ++++--- engineioxide/src/payload.rs | 28 ++++++++++------ engineioxide/src/service.rs | 15 +++++---- engineioxide/src/socket.rs | 55 ++++++++---------------------- socketioxide/src/adapter.rs | 7 ++-- socketioxide/src/client.rs | 2 +- 9 files changed, 88 insertions(+), 114 deletions(-) diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 4b4d7357..74c42cca 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -4,7 +4,6 @@ use std::{ sync::{Arc, RwLock}, }; -use crate::{sid_generator::Sid, payload::{Payload, PACKET_SEPARATOR}, service::ProtocolVersion}; use crate::{ body::ResponseBody, config::EngineIoConfig, @@ -16,6 +15,11 @@ use crate::{ sid_generator::generate_sid, socket::{ConnectionType, Socket, SocketReq}, }; +use crate::{ + payload::{Payload, PACKET_SEPARATOR}, + service::ProtocolVersion, + sid_generator::Sid, +}; use bytes::Buf; use futures::{stream::SplitStream, SinkExt, StreamExt, TryStreamExt}; use http::{Request, Response, StatusCode}; @@ -30,15 +34,13 @@ use tracing::debug; type SocketMap = RwLock>>; /// Abstract engine implementation for Engine.IO server for http polling and websocket /// It handle all the connection logic and dispatch the packets to the socket -pub struct EngineIo -{ +pub struct EngineIo { sockets: SocketMap>, handler: H, pub config: EngineIoConfig, } -impl EngineIo -{ +impl EngineIo { /// Create a new Engine.IO server with a handler and a config pub fn new(handler: H, config: EngineIoConfig) -> Self { Self { @@ -49,8 +51,7 @@ impl EngineIo } } -impl EngineIo -{ +impl EngineIo { /// Handle Open request /// Create a new socket and add it to the socket map /// Start the heartbeat task @@ -141,7 +142,9 @@ impl EngineIo // The V4 protocol (and up) only requires a packet separator. match protocol { ProtocolVersion::V3 => data.push_str(&format!("{}:", packet.chars().count())), - ProtocolVersion::V4 => data.push(std::char::from_u32(PACKET_SEPARATOR as u32).unwrap()), + ProtocolVersion::V4 => { + data.push(std::char::from_u32(PACKET_SEPARATOR as u32).unwrap()) + } } } else if protocol == ProtocolVersion::V3 { data.push_str(&format!("{}:", packet.chars().count())); @@ -153,7 +156,8 @@ impl EngineIo if data.is_empty() { let packet = rx.recv().await.ok_or(Error::Aborted)?; let packet: String = packet.try_into().unwrap(); - #[cfg(feature = "v3")] { + #[cfg(feature = "v3")] + { // The V3 protocol specifically requires the packet length to be prepended to the packet. if protocol == ProtocolVersion::V3 { data.push_str(&format!("{}:", packet.chars().count())); @@ -183,7 +187,7 @@ impl EngineIo debug!("error aggregating body: {:?}", e); Error::HttpErrorResponse(StatusCode::BAD_REQUEST) })?; - + let socket = self .get_socket(sid) .ok_or(Error::UnknownSessionID(sid)) @@ -195,9 +199,7 @@ impl EngineIo let raw_packet = p.map_err(|e| { debug!("error parsing packets: {:?}", e); self.close_session(sid); - Error::HttpErrorResponse( - StatusCode::BAD_REQUEST, - ) + Error::HttpErrorResponse(StatusCode::BAD_REQUEST) })?; match Packet::try_from(raw_packet) { @@ -207,12 +209,10 @@ impl EngineIo self.close_session(sid); break; } - Ok(Packet::Pong) | Ok(Packet::Ping) => { - socket - .heartbeat_tx - .try_send(()) - .map_err(|_| Error::HeartbeatTimeout) - }, + Ok(Packet::Pong) | Ok(Packet::Ping) => socket + .heartbeat_tx + .try_send(()) + .map_err(|_| Error::HeartbeatTimeout), Ok(Packet::Message(msg)) => { self.handler.on_message(msg, &socket); Ok(()) @@ -224,7 +224,7 @@ impl EngineIo Ok(p) => { debug!("[sid={sid}] bad packet received: {:?}", &p); Err(Error::BadPacket(p)) - }, + } Err(e) => { debug!("[sid={sid}] error parsing packet: {:?}", e); self.close_session(sid); @@ -323,7 +323,9 @@ impl EngineIo let mut socket_rx = rx_socket.internal_rx.try_lock().unwrap(); while let Some(item) = socket_rx.recv().await { let res = match item { - Packet::Binary(bin) | Packet::BinaryV3(bin) => tx.send(Message::Binary(bin)).await, + Packet::Binary(bin) | Packet::BinaryV3(bin) => { + tx.send(Message::Binary(bin)).await + } Packet::Close => tx.send(Message::Close(None)).await, _ => { let packet: String = item.try_into().unwrap(); @@ -362,12 +364,10 @@ impl EngineIo self.close_session(socket.sid); break; } - Packet::Pong | Packet::Ping => { - socket - .heartbeat_tx - .try_send(()) - .map_err(|_| Error::HeartbeatTimeout) - }, + Packet::Pong | Packet::Ping => socket + .heartbeat_tx + .try_send(()) + .map_err(|_| Error::HeartbeatTimeout), Packet::Message(msg) => { self.handler.on_message(msg, socket); Ok(()) @@ -416,7 +416,8 @@ impl EngineIo ) -> Result<(), Error> { let socket = self.get_socket(sid).unwrap(); - #[cfg(feature = "v4")] { + #[cfg(feature = "v4")] + { // send a NOOP packet to any pending polling request so it closes gracefully' if protocol == ProtocolVersion::V4 { socket.send(Packet::Noop)?; @@ -437,9 +438,10 @@ impl EngineIo p => Err(Error::BadPacket(p))?, }; - #[cfg(feature = "v3")] { + #[cfg(feature = "v3")] + { // send a NOOP packet to any pending polling request so it closes gracefully - // V3 protocol introduce _paused_ polling transport which require to close + // V3 protocol introduce _paused_ polling transport which require to close // the polling request **after** the ping/pong handshake if protocol == ProtocolVersion::V3 { socket.send(Packet::Noop)?; @@ -452,11 +454,11 @@ impl EngineIo Some(Ok(Message::Close(_))) => { debug!("ws stream closed before upgrade"); Err(Error::UpgradeError)? - }, + } _ => { debug!("unexpected ws message before upgrade"); Err(Error::UpgradeError)? - }, + } }; match Packet::try_from(msg)? { Packet::Upgrade => debug!("[sid={sid}] ws upgraded successful"), @@ -503,7 +505,6 @@ impl EngineIo } } - #[cfg(test)] mod tests { use async_trait::async_trait; diff --git a/engineioxide/src/layer.rs b/engineioxide/src/layer.rs index b8d4109c..1ceae5de 100644 --- a/engineioxide/src/layer.rs +++ b/engineioxide/src/layer.rs @@ -3,14 +3,12 @@ use tower::Layer; use crate::{config::EngineIoConfig, handler::EngineIoHandler, service::EngineIoService}; #[derive(Debug, Clone)] -pub struct EngineIoLayer -{ +pub struct EngineIoLayer { config: EngineIoConfig, handler: H, } -impl EngineIoLayer -{ +impl EngineIoLayer { pub fn new(handler: H) -> Self { Self { config: EngineIoConfig::default(), @@ -18,15 +16,11 @@ impl EngineIoLayer } } pub fn from_config(handler: H, config: EngineIoConfig) -> Self { - Self { - config, - handler, - } + Self { config, handler } } } -impl Layer for EngineIoLayer -{ +impl Layer for EngineIoLayer { type Service = EngineIoService; fn layer(&self, inner: S) -> Self::Service { diff --git a/engineioxide/src/lib.rs b/engineioxide/src/lib.rs index f68a6c4d..75e61e53 100644 --- a/engineioxide/src/lib.rs +++ b/engineioxide/src/lib.rs @@ -18,5 +18,5 @@ mod body; mod engine; mod futures; mod packet; -mod utils; mod payload; +mod utils; diff --git a/engineioxide/src/packet.rs b/engineioxide/src/packet.rs index 9b07ba89..706d6f05 100644 --- a/engineioxide/src/packet.rs +++ b/engineioxide/src/packet.rs @@ -53,7 +53,7 @@ pub enum Packet { /// Or to a websocket binary frame when using websocket connection /// /// When receiving, it is only used with polling connection, websocket use binary frame - /// + /// /// This is a special packet, excepionally specific to the V3 protocol. BinaryV3(Vec), // Not part of the protocol, used internally } @@ -85,9 +85,9 @@ impl TryFrom for Packet { type Error = crate::errors::Error; fn try_from(value: String) -> Result { let mut chars = value.chars(); - let packet_type = chars.next().ok_or_else(|| serde_json::Error::custom( - "Packet type not found in packet string", - ))?; + let packet_type = chars + .next() + .ok_or_else(|| serde_json::Error::custom("Packet type not found in packet string"))?; let packet_data = chars.as_str(); let is_upgrade = packet_data.starts_with("probe"); let res = match packet_type { @@ -110,7 +110,9 @@ impl TryFrom for Packet { '4' => Packet::Message(packet_data.to_string()), '5' => Packet::Upgrade, '6' => Packet::Noop, - 'b' if value.starts_with("b4") => Packet::BinaryV3(general_purpose::STANDARD.decode(packet_data[1..].as_bytes())?), + 'b' if value.starts_with("b4") => { + Packet::BinaryV3(general_purpose::STANDARD.decode(packet_data[1..].as_bytes())?) + } 'b' => Packet::Binary(general_purpose::STANDARD.decode(packet_data.as_bytes())?), c => Err(serde_json::Error::custom( "Invalid packet type ".to_string() + &c.to_string(), diff --git a/engineioxide/src/payload.rs b/engineioxide/src/payload.rs index dcb2f80c..bff81a0b 100644 --- a/engineioxide/src/payload.rs +++ b/engineioxide/src/payload.rs @@ -1,6 +1,6 @@ use std::{io::BufRead, vec}; -use crate::{service::ProtocolVersion, errors::Error}; +use crate::{errors::Error, service::ProtocolVersion}; pub const PACKET_SEPARATOR: u8 = b'\x1e'; @@ -36,7 +36,7 @@ impl Payload { let length = std::str::from_utf8(&buffer) .map_err(|_| Error::InvalidPacketLength) .and_then(|s| s.parse::().map_err(|_| Error::InvalidPacketLength))?; - + self.buffer.resize(length, 0); self.reader.read_exact(&mut self.buffer)?; @@ -56,7 +56,7 @@ impl Payload { if self.buffer.ends_with(&[PACKET_SEPARATOR]) { self.buffer.pop(); } - + let buffer = std::mem::take(&mut self.buffer); Some(String::from_utf8(buffer).map_err(Into::into)) // TODO: replace 'String::from_utf8' with 'std::str::from_utf8' } else { @@ -74,12 +74,8 @@ impl Iterator for Payload { #[cfg(all(feature = "v3", feature = "v4"))] fn next(&mut self) -> Option { match self.protocol { - ProtocolVersion::V3 => { - self.next_v3() - }, - ProtocolVersion::V4 => { - self.next_v4() - }, + ProtocolVersion::V3 => self.next_v3(), + ProtocolVersion::V4 => self.next_v4(), } } @@ -98,7 +94,10 @@ impl Iterator for Payload { #[cfg(test)] mod tests { - use std::{io::{BufReader, Cursor}, vec}; + use std::{ + io::{BufReader, Cursor}, + vec, + }; use crate::service::ProtocolVersion; @@ -109,7 +108,14 @@ mod tests { assert!(cfg!(feature = "v4")); let data = BufReader::new(Cursor::new(vec![ - b'f', b'o', b'o', PACKET_SEPARATOR, b'f', b'o', PACKET_SEPARATOR, b'f', + b'f', + b'o', + b'o', + PACKET_SEPARATOR, + b'f', + b'o', + PACKET_SEPARATOR, + b'f', ])); let mut payload = Payload::new(ProtocolVersion::V4, data); diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs index 68131de1..c23d67cd 100644 --- a/engineioxide/src/service.rs +++ b/engineioxide/src/service.rs @@ -2,13 +2,10 @@ use crate::{ body::ResponseBody, config::EngineIoConfig, engine::EngineIo, - errors::{ - Error, - Error::{UnknownTransport}, - }, + errors::{Error, Error::UnknownTransport}, futures::ResponseFuture, handler::EngineIoHandler, - sid_generator::Sid + sid_generator::Sid, }; use bytes::Bytes; use futures::future::{ready, Ready}; @@ -111,13 +108,17 @@ where sid: Some(sid), transport: TransportType::Polling, method: Method::GET, - }) => ResponseFuture::async_response(Box::pin(engine.on_polling_http_req(protocol, sid))), + }) => ResponseFuture::async_response(Box::pin( + engine.on_polling_http_req(protocol, sid), + )), Ok(RequestInfo { protocol, sid: Some(sid), transport: TransportType::Polling, method: Method::POST, - }) => ResponseFuture::async_response(Box::pin(engine.on_post_http_req(protocol, sid, req))), + }) => ResponseFuture::async_response(Box::pin( + engine.on_post_http_req(protocol, sid, req), + )), Ok(RequestInfo { protocol, sid, diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index 042053fd..e6b89c8a 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -13,11 +13,11 @@ use tokio::{ }; use tracing::debug; +use crate::sid_generator::Sid; use crate::{ config::EngineIoConfig, errors::Error, handler::EngineIoHandler, packet::Packet, - utils::forward_map_chan, SendPacket, service::ProtocolVersion, + service::ProtocolVersion, utils::forward_map_chan, SendPacket, }; -use crate::sid_generator::Sid; #[derive(Debug, Clone, PartialEq)] pub(crate) enum ConnectionType { @@ -161,11 +161,7 @@ where /// Spawn the heartbeat job /// /// Keep a handle to the job so that it can be aborted when the socket is closed - pub(crate) fn spawn_heartbeat( - self: Arc, - interval: Duration, - timeout: Duration, - ) { + pub(crate) fn spawn_heartbeat(self: Arc, interval: Duration, timeout: Duration) { let socket = self.clone(); let handle = tokio::spawn(async move { @@ -184,18 +180,10 @@ where /// /// If the client or server does not respond within the timeout, the connection is closed. #[cfg(all(feature = "v3", feature = "v4"))] - async fn heartbeat_job( - &self, - interval: Duration, - timeout: Duration, - ) -> Result<(), Error> { + async fn heartbeat_job(&self, interval: Duration, timeout: Duration) -> Result<(), Error> { match self.protocol { - ProtocolVersion::V3 => { - self.heartbeat_job_v3(timeout).await - } - ProtocolVersion::V4 => { - self.heartbeat_job_v4(interval, timeout).await - } + ProtocolVersion::V3 => self.heartbeat_job_v3(timeout).await, + ProtocolVersion::V4 => self.heartbeat_job_v4(interval, timeout).await, } } @@ -204,24 +192,16 @@ where /// If the client or server does not respond within the timeout, the connection is closed. #[cfg(feature = "v3")] #[cfg(not(feature = "v4"))] - async fn heartbeat_job( - &self, - interval: Duration, - timeout: Duration, - ) -> Result<(), Error> { + async fn heartbeat_job(&self, interval: Duration, timeout: Duration) -> Result<(), Error> { self.heartbeat_job_v3(timeout) } - /// Heartbeat is sent every `interval` milliseconds and the client is expected to respond within `timeout` milliseconds. + /// Heartbeat is sent every `interval` milliseconds and the client is expected to respond within `timeout` milliseconds. /// /// If the client does not respond within the timeout, the connection is closed. #[cfg(feature = "v4")] #[cfg(not(feature = "v3"))] - async fn heartbeat_job( - &self, - interval: Duration, - timeout: Duration, - ) -> Result<(), Error> { + async fn heartbeat_job(&self, interval: Duration, timeout: Duration) -> Result<(), Error> { self.heartbeat_job_v4(interval, timeout).await } @@ -229,11 +209,7 @@ where /// /// If the client does not respond within the timeout, the connection is closed. #[cfg(feature = "v4")] - async fn heartbeat_job_v4( - &self, - interval: Duration, - timeout: Duration, - ) -> Result<(), Error> { + async fn heartbeat_job_v4(&self, interval: Duration, timeout: Duration) -> Result<(), Error> { let mut heartbeat_rx = self .heartbeat_rx .try_lock() @@ -249,7 +225,7 @@ where .await; debug!("[sid={}] heartbeat sender routine started", self.sid); - + loop { // Some clients send the pong packet in first. If that happens, we should consume it. heartbeat_rx.try_recv().ok(); @@ -266,17 +242,14 @@ where } #[cfg(feature = "v3")] - async fn heartbeat_job_v3( - &self, - timeout: Duration, - ) -> Result<(), Error> { + async fn heartbeat_job_v3(&self, timeout: Duration) -> Result<(), Error> { let mut heartbeat_rx = self .heartbeat_rx .try_lock() .expect("Pong rx should be locked only once"); - + debug!("[sid={}] heartbeat receiver routine started", self.sid); - + loop { tokio::time::timeout(timeout, heartbeat_rx.recv()) .await diff --git a/socketioxide/src/adapter.rs b/socketioxide/src/adapter.rs index fc8eba4c..ccde4d7d 100644 --- a/socketioxide/src/adapter.rs +++ b/socketioxide/src/adapter.rs @@ -18,15 +18,12 @@ use itertools::Itertools; use serde::de::DeserializeOwned; use crate::{ - errors::{ - AckError, - BroadcastError, - }, + errors::{AckError, BroadcastError}, handler::AckResponse, ns::Namespace, operators::RoomParam, packet::Packet, - socket::Socket + socket::Socket, }; /// A room identifier diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index 6e05af61..0baa04ec 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -185,4 +185,4 @@ impl Clone for Client { ns: self.ns.clone(), } } -} \ No newline at end of file +}