diff --git a/socketio/src/asynchronous/client/builder.rs b/socketio/src/asynchronous/client/builder.rs index 44710e19..e5b6d628 100644 --- a/socketio/src/asynchronous/client/builder.rs +++ b/socketio/src/asynchronous/client/builder.rs @@ -8,7 +8,7 @@ use rust_engineio::{ use std::collections::HashMap; use url::Url; -use crate::{error::Result, Event, Payload, TransportType}; +use crate::{error::Result, packet::PacketParser, Event, Payload, TransportType}; use super::{ callback::{ @@ -31,6 +31,7 @@ pub struct ClientBuilder { tls_config: Option, opening_headers: Option, transport_type: TransportType, + packet_parser: PacketParser, pub(crate) auth: Option, pub(crate) reconnect: bool, pub(crate) reconnect_on_disconnect: bool, @@ -90,6 +91,7 @@ impl ClientBuilder { tls_config: None, opening_headers: None, transport_type: TransportType::Any, + packet_parser: PacketParser::default(), auth: None, reconnect: true, reconnect_on_disconnect: false, @@ -453,7 +455,7 @@ impl ClientBuilder { TransportType::WebsocketUpgrade => builder.build_websocket_with_upgrade().await?, }; - let inner_socket = InnerSocket::new(engine_client)?; + let inner_socket = InnerSocket::new(engine_client, self.packet_parser.clone())?; Ok(inner_socket) } diff --git a/socketio/src/asynchronous/socket.rs b/socketio/src/asynchronous/socket.rs index 81a9ebcd..b99affd3 100644 --- a/socketio/src/asynchronous/socket.rs +++ b/socketio/src/asynchronous/socket.rs @@ -1,11 +1,10 @@ use super::generator::StreamGenerator; use crate::{ error::Result, - packet::{Packet, PacketId}, + packet::{Packet, PacketId, PacketParser}, Error, Event, Payload, }; use async_stream::try_stream; -use bytes::Bytes; use futures_util::{Stream, StreamExt}; use rust_engineio::{ asynchronous::Client as EngineClient, Packet as EnginePacket, PacketId as EnginePacketId, @@ -24,16 +23,22 @@ pub(crate) struct Socket { engine_client: Arc, connected: Arc, generator: StreamGenerator, + packet_parser: PacketParser, } impl Socket { /// Creates an instance of `Socket`. - pub(super) fn new(engine_client: EngineClient) -> Result { + pub(super) fn new(engine_client: EngineClient, packet_parser: PacketParser) -> Result { let connected = Arc::new(AtomicBool::default()); Ok(Socket { engine_client: Arc::new(engine_client.clone()), connected: connected.clone(), - generator: StreamGenerator::new(Self::stream(engine_client, connected)), + generator: StreamGenerator::new(Self::stream( + engine_client, + connected, + packet_parser.clone(), + )), + packet_parser, }) } @@ -68,7 +73,8 @@ impl Socket { } // the packet, encoded as an engine.io message packet - let engine_packet = EnginePacket::new(EnginePacketId::Message, Bytes::from(&packet)); + let engine_packet = + EnginePacket::new(EnginePacketId::Message, self.packet_parser.encode(&packet)); self.engine_client.emit(engine_packet).await?; if let Some(attachments) = packet.attachments { @@ -92,6 +98,7 @@ impl Socket { fn stream( client: EngineClient, is_connected: Arc, + parser: PacketParser, ) -> Pin> + Send>> { Box::pin(try_stream! { for await received_data in client.clone() { @@ -100,7 +107,7 @@ impl Socket { if packet.packet_id == EnginePacketId::Message || packet.packet_id == EnginePacketId::MessageBinary { - let packet = Self::handle_engineio_packet(packet, client.clone()).await?; + let packet = Self::handle_engineio_packet(packet, client.clone(), &parser).await?; Self::handle_socketio_packet(&packet, is_connected.clone()); yield packet; @@ -130,8 +137,9 @@ impl Socket { async fn handle_engineio_packet( packet: EnginePacket, mut client: EngineClient, + parser: &PacketParser, ) -> Result { - let mut socket_packet = Packet::try_from(&packet.data)?; + let mut socket_packet = parser.decode(&packet.data)?; // Only handle attachments if there are any if socket_packet.attachment_count > 0 { diff --git a/socketio/src/client/builder.rs b/socketio/src/client/builder.rs index 724971f0..fe2a1dce 100644 --- a/socketio/src/client/builder.rs +++ b/socketio/src/client/builder.rs @@ -1,6 +1,7 @@ use super::super::{event::Event, payload::Payload}; use super::callback::Callback; use super::client::Client; +use crate::packet::PacketParser; use crate::RawClient; use native_tls::TlsConnector; use rust_engineio::client::ClientBuilder as EngineIoClientBuilder; @@ -40,6 +41,7 @@ pub struct ClientBuilder { tls_config: Option, opening_headers: Option, transport_type: TransportType, + packet_parser: PacketParser, auth: Option, pub(crate) reconnect: bool, pub(crate) reconnect_on_disconnect: bool, @@ -91,6 +93,7 @@ impl ClientBuilder { tls_config: None, opening_headers: None, transport_type: TransportType::Any, + packet_parser: PacketParser::default(), auth: None, reconnect: true, reconnect_on_disconnect: false, @@ -306,6 +309,13 @@ impl ClientBuilder { self } + /// Specifies how to parser Packet + pub fn packet_parser(mut self, packet_parser: PacketParser) -> Self { + self.packet_parser = packet_parser; + + self + } + /// Connects the socket to a certain endpoint. This returns a connected /// [`Client`] instance. This method returns an [`std::result::Result::Err`] /// value if something goes wrong during connection. Also starts a separate @@ -357,7 +367,7 @@ impl ClientBuilder { TransportType::WebsocketUpgrade => builder.build_websocket_with_upgrade()?, }; - let inner_socket = InnerSocket::new(engine_client)?; + let inner_socket = InnerSocket::new(engine_client, self.packet_parser.clone())?; let socket = RawClient::new( inner_socket, diff --git a/socketio/src/error.rs b/socketio/src/error.rs index cc25d897..c7049a59 100644 --- a/socketio/src/error.rs +++ b/socketio/src/error.rs @@ -20,6 +20,8 @@ pub enum Error { IncompletePacket(), #[error("Got an invalid packet which did not follow the protocol format")] InvalidPacket(), + #[error("Error while parsing an incomplete packet: {0}")] + ParsePacketFailed(String), #[error("An error occurred while decoding the utf-8 text: {0}")] InvalidUtf8(#[from] Utf8Error), #[error("An error occurred while encoding/decoding base64: {0}")] diff --git a/socketio/src/packet.rs b/socketio/src/packet.rs index e74dedb5..bbee50d4 100644 --- a/socketio/src/packet.rs +++ b/socketio/src/packet.rs @@ -4,8 +4,9 @@ use bytes::Bytes; use serde::de::IgnoredAny; use std::convert::TryFrom; -use std::fmt::Write; +use std::fmt::{Debug, Display, Write}; use std::str::from_utf8 as str_from_utf8; +use std::sync::Arc; /// An enumeration of the different `Packet` types in the `socket.io` protocol. #[derive(Debug, Copy, Clone, Eq, PartialEq)] @@ -30,134 +31,73 @@ pub struct Packet { pub attachments: Option>, } -impl Packet { - /// Returns a packet for a payload, could be used for both binary and non binary - /// events and acks. Convenience method. - #[inline] - pub(crate) fn new_from_payload<'a>( - payload: Payload, - event: Event, - nsp: &'a str, - id: Option, - ) -> Result { - match payload { - Payload::Binary(bin_data) => Ok(Packet::new( - if id.is_some() { - PacketId::BinaryAck - } else { - PacketId::BinaryEvent - }, - nsp.to_owned(), - Some(serde_json::Value::String(event.into()).to_string()), - id, - 1, - Some(vec![bin_data]), - )), - #[allow(deprecated)] - Payload::String(str_data) => { - let payload = if serde_json::from_str::(&str_data).is_ok() { - format!("[\"{event}\",{str_data}]") - } else { - format!("[\"{event}\",\"{str_data}\"]") - }; - - Ok(Packet::new( - PacketId::Event, - nsp.to_owned(), - Some(payload), - id, - 0, - None, - )) - } - Payload::Text(mut data) => { - let mut payload_args = vec![serde_json::Value::String(event.to_string())]; - payload_args.append(&mut data); - drop(data); +#[derive(Clone)] +/// Use to serialize and deserialize packets +/// +/// support [Custom parser](https://socket.io/docs/v4/custom-parser/) +pub struct PacketParser { + encode: Arc Bytes + Send + Sync>>, + decode: Arc Result + Send + Sync>>, +} - let payload = serde_json::Value::Array(payload_args).to_string(); +impl Display for PacketParser { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "PacketSerializer") + } +} - Ok(Packet::new( - PacketId::Event, - nsp.to_owned(), - Some(payload), - id, - 0, - None, - )) - } - } +impl Debug for PacketParser { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PacketSerializer").finish() } } -impl Default for Packet { +impl Default for PacketParser { fn default() -> Self { Self { - packet_type: PacketId::Event, - nsp: String::from("/"), - data: None, - id: None, - attachment_count: 0, - attachments: None, + encode: Arc::new(Box::new(Self::default_encode)), + decode: Arc::new(Box::new(Self::default_decode)), } } } -impl TryFrom for PacketId { - type Error = Error; - fn try_from(b: u8) -> Result { - PacketId::try_from(b as char) +impl PacketParser { + /// Creates a new instance of `PacketSerializer` with both encode and decode functions. + pub fn new( + encode: Box Bytes + Send + Sync>, + decode: Box Result + Send + Sync>, + ) -> Self { + Self { + encode: Arc::new(encode), + decode: Arc::new(decode), + } } -} -impl TryFrom for PacketId { - type Error = Error; - fn try_from(b: char) -> Result { - match b { - '0' => Ok(PacketId::Connect), - '1' => Ok(PacketId::Disconnect), - '2' => Ok(PacketId::Event), - '3' => Ok(PacketId::Ack), - '4' => Ok(PacketId::ConnectError), - '5' => Ok(PacketId::BinaryEvent), - '6' => Ok(PacketId::BinaryAck), - _ => Err(Error::InvalidPacketId(b)), + /// Creates a new instance of `PacketSerializer` with only encode function. and a default decode function. + pub fn new_encode(encode: Box Bytes + Send + Sync>) -> Self { + Self { + encode: Arc::new(encode), + decode: Arc::new(Box::new(Self::default_decode)), } } -} -impl Packet { - /// Creates an instance. - pub const fn new( - packet_type: PacketId, - nsp: String, - data: Option, - id: Option, - attachment_count: u8, - attachments: Option>, - ) -> Self { - Packet { - packet_type, - nsp, - data, - id, - attachment_count, - attachments, + /// Creates a new instance of `PacketSerializer` with only decode function. and a default encode function. + pub fn new_decode(decode: Box Result + Send + Sync>) -> Self { + Self { + encode: Arc::new(Box::new(Self::default_encode)), + decode: Arc::new(decode), } } -} -impl From for Bytes { - fn from(packet: Packet) -> Self { - Bytes::from(&packet) + pub fn encode(&self, packet: &Packet) -> Bytes { + (self.encode)(packet) } -} -impl From<&Packet> for Bytes { - /// Method for encoding from a `Packet` to a `u8` byte stream. - /// The binary payload of a packet is not put at the end of the - /// stream as it gets handled and send by it's own logic via the socket. - fn from(packet: &Packet) -> Bytes { + pub fn decode(&self, payload: &Bytes) -> Result { + (self.decode)(payload) + } + + pub fn default_encode(packet: &Packet) -> Bytes { // first the packet type let mut buffer = String::new(); buffer.push((packet.packet_type as u8 + b'0') as char); @@ -197,31 +137,15 @@ impl From<&Packet> for Bytes { Bytes::from(buffer) } -} - -impl TryFrom for Packet { - type Error = Error; - fn try_from(value: Bytes) -> Result { - Packet::try_from(&value) - } -} -impl TryFrom<&Bytes> for Packet { - type Error = Error; - /// Decodes a packet given a `Bytes` type. - /// The binary payload of a packet is not put at the end of the - /// stream as it gets handled and send by it's own logic via the socket. - /// Therefore this method does not return the correct value for the - /// binary data, instead the socket is responsible for handling - /// this member. This is done because the attachment is usually - /// send in another packet. - fn try_from(payload: &Bytes) -> Result { + pub fn default_decode(payload: &Bytes) -> Result { let mut payload = str_from_utf8(&payload).map_err(Error::InvalidUtf8)?; let mut packet = Packet::default(); // packet_type let id_char = payload.chars().next().ok_or(Error::IncompletePacket())?; packet.packet_type = PacketId::try_from(id_char)?; + payload = &payload[id_char.len_utf8()..]; // attachment_count @@ -277,6 +201,123 @@ impl TryFrom<&Bytes> for Packet { } } +impl Packet { + /// Returns a packet for a payload, could be used for both binary and non binary + /// events and acks. Convenience method. + #[inline] + pub(crate) fn new_from_payload<'a>( + payload: Payload, + event: Event, + nsp: &'a str, + id: Option, + ) -> Result { + match payload { + Payload::Binary(bin_data) => Ok(Packet::new( + if id.is_some() { + PacketId::BinaryAck + } else { + PacketId::BinaryEvent + }, + nsp.to_owned(), + Some(serde_json::Value::String(event.into()).to_string()), + id, + 1, + Some(vec![bin_data]), + )), + #[allow(deprecated)] + Payload::String(str_data) => { + let payload = if serde_json::from_str::(&str_data).is_ok() { + format!("[\"{event}\",{str_data}]") + } else { + format!("[\"{event}\",\"{str_data}\"]") + }; + + Ok(Packet::new( + PacketId::Event, + nsp.to_owned(), + Some(payload), + id, + 0, + None, + )) + } + Payload::Text(mut data) => { + let mut payload_args = vec![serde_json::Value::String(event.to_string())]; + payload_args.append(&mut data); + drop(data); + + let payload = serde_json::Value::Array(payload_args).to_string(); + + Ok(Packet::new( + PacketId::Event, + nsp.to_owned(), + Some(payload), + id, + 0, + None, + )) + } + } + } +} + +impl Default for Packet { + fn default() -> Self { + Self { + packet_type: PacketId::Event, + nsp: String::from("/"), + data: None, + id: None, + attachment_count: 0, + attachments: None, + } + } +} + +impl TryFrom for PacketId { + type Error = Error; + fn try_from(b: u8) -> Result { + PacketId::try_from(b as char) + } +} + +impl TryFrom for PacketId { + type Error = Error; + fn try_from(b: char) -> Result { + match b { + '0' => Ok(PacketId::Connect), + '1' => Ok(PacketId::Disconnect), + '2' => Ok(PacketId::Event), + '3' => Ok(PacketId::Ack), + '4' => Ok(PacketId::ConnectError), + '5' => Ok(PacketId::BinaryEvent), + '6' => Ok(PacketId::BinaryAck), + _ => Err(Error::InvalidPacketId(b)), + } + } +} + +impl Packet { + /// Creates an instance. + pub const fn new( + packet_type: PacketId, + nsp: String, + data: Option, + id: Option, + attachment_count: u8, + attachments: Option>, + ) -> Self { + Packet { + packet_type, + nsp, + data, + id, + attachment_count, + attachments, + } + } +} + #[cfg(test)] mod test { use super::*; @@ -286,7 +327,7 @@ mod test { /// https://github.com/socketio/socket.io-protocol fn test_decode() { let payload = Bytes::from_static(b"0{\"token\":\"123\"}"); - let packet = Packet::try_from(&payload); + let packet = PacketParser::default_decode(&payload); assert!(packet.is_ok()); assert_eq!( @@ -304,7 +345,7 @@ mod test { let utf8_data = "{\"token™\":\"123\"}".to_owned(); let utf8_payload = format!("0/admin™,{}", utf8_data); let payload = Bytes::from(utf8_payload); - let packet = Packet::try_from(&payload); + let packet = PacketParser::default_decode(&payload); assert!(packet.is_ok()); assert_eq!( @@ -320,7 +361,7 @@ mod test { ); let payload = Bytes::from_static(b"1/admin,"); - let packet = Packet::try_from(&payload); + let packet = PacketParser::default_decode(&payload); assert!(packet.is_ok()); assert_eq!( @@ -336,7 +377,7 @@ mod test { ); let payload = Bytes::from_static(b"2[\"hello\",1]"); - let packet = Packet::try_from(&payload); + let packet = PacketParser::default_decode(&payload); assert!(packet.is_ok()); assert_eq!( @@ -352,7 +393,7 @@ mod test { ); let payload = Bytes::from_static(b"2/admin,456[\"project:delete\",123]"); - let packet = Packet::try_from(&payload); + let packet = PacketParser::default_decode(&payload); assert!(packet.is_ok()); assert_eq!( @@ -368,7 +409,7 @@ mod test { ); let payload = Bytes::from_static(b"3/admin,456[]"); - let packet = Packet::try_from(&payload); + let packet = PacketParser::default_decode(&payload); assert!(packet.is_ok()); assert_eq!( @@ -384,7 +425,7 @@ mod test { ); let payload = Bytes::from_static(b"4/admin,{\"message\":\"Not authorized\"}"); - let packet = Packet::try_from(&payload); + let packet = PacketParser::default_decode(&payload); assert!(packet.is_ok()); assert_eq!( @@ -400,7 +441,7 @@ mod test { ); let payload = Bytes::from_static(b"51-[\"hello\",{\"_placeholder\":true,\"num\":0}]"); - let packet = Packet::try_from(&payload); + let packet = PacketParser::default_decode(&payload); assert!(packet.is_ok()); assert_eq!( @@ -418,7 +459,7 @@ mod test { let payload = Bytes::from_static( b"51-/admin,456[\"project:delete\",{\"_placeholder\":true,\"num\":0}]", ); - let packet = Packet::try_from(&payload); + let packet = PacketParser::default_decode(&payload); assert!(packet.is_ok()); assert_eq!( @@ -434,7 +475,7 @@ mod test { ); let payload = Bytes::from_static(b"61-/admin,456[{\"_placeholder\":true,\"num\":0}]"); - let packet = Packet::try_from(&payload); + let packet = PacketParser::default_decode(&payload); assert!(packet.is_ok()); assert_eq!( @@ -464,7 +505,7 @@ mod test { ); assert_eq!( - Bytes::from(&packet), + PacketParser::default_encode(&packet), "0{\"token\":\"123\"}".to_string().into_bytes() ); @@ -478,7 +519,7 @@ mod test { ); assert_eq!( - Bytes::from(&packet), + PacketParser::default_encode(&packet), "0/admin,{\"token\":\"123\"}".to_string().into_bytes() ); @@ -491,7 +532,10 @@ mod test { None, ); - assert_eq!(Bytes::from(&packet), "1/admin,".to_string().into_bytes()); + assert_eq!( + PacketParser::default_encode(&packet), + "1/admin,".to_string().into_bytes() + ); let packet = Packet::new( PacketId::Event, @@ -503,7 +547,7 @@ mod test { ); assert_eq!( - Bytes::from(&packet), + PacketParser::default_encode(&packet), "2[\"hello\",1]".to_string().into_bytes() ); @@ -517,7 +561,7 @@ mod test { ); assert_eq!( - Bytes::from(&packet), + PacketParser::default_encode(&packet), "2/admin,456[\"project:delete\",123]" .to_string() .into_bytes() @@ -533,7 +577,7 @@ mod test { ); assert_eq!( - Bytes::from(&packet), + PacketParser::default_encode(&packet), "3/admin,456[]".to_string().into_bytes() ); @@ -547,7 +591,7 @@ mod test { ); assert_eq!( - Bytes::from(&packet), + PacketParser::default_encode(&packet), "4/admin,{\"message\":\"Not authorized\"}" .to_string() .into_bytes() @@ -563,7 +607,7 @@ mod test { ); assert_eq!( - Bytes::from(&packet), + PacketParser::default_encode(&packet), "51-[\"hello\",{\"_placeholder\":true,\"num\":0}]" .to_string() .into_bytes() @@ -579,7 +623,7 @@ mod test { ); assert_eq!( - Bytes::from(&packet), + PacketParser::default_encode(&packet), "51-/admin,456[\"project:delete\",{\"_placeholder\":true,\"num\":0}]" .to_string() .into_bytes() @@ -595,7 +639,7 @@ mod test { ); assert_eq!( - Bytes::from(&packet), + PacketParser::default_encode(&packet), "61-/admin,456[{\"_placeholder\":true,\"num\":0}]" .to_string() .into_bytes() diff --git a/socketio/src/socket.rs b/socketio/src/socket.rs index b881bad0..36b3c3d9 100644 --- a/socketio/src/socket.rs +++ b/socketio/src/socket.rs @@ -1,28 +1,28 @@ use crate::error::{Error, Result}; -use crate::packet::{Packet, PacketId}; -use bytes::Bytes; +use crate::event::Event; +use crate::packet::{Packet, PacketId, PacketParser}; +use crate::payload::Payload; use rust_engineio::{Client as EngineClient, Packet as EnginePacket, PacketId as EnginePacketId}; -use std::convert::TryFrom; use std::sync::{atomic::AtomicBool, Arc}; use std::{fmt::Debug, sync::atomic::Ordering}; -use super::{event::Event, payload::Payload}; - /// Handles communication in the `socket.io` protocol. #[derive(Clone, Debug)] pub(crate) struct Socket { - //TODO: 0.4.0 refactor this + // TODO: 0.4.0 refactor this engine_client: Arc, connected: Arc, + packet_parser: PacketParser, } impl Socket { /// Creates an instance of `Socket`. - pub(super) fn new(engine_client: EngineClient) -> Result { + pub(super) fn new(engine_client: EngineClient, packet_parser: PacketParser) -> Result { Ok(Socket { engine_client: Arc::new(engine_client), connected: Arc::new(AtomicBool::default()), + packet_parser, }) } @@ -57,7 +57,8 @@ impl Socket { } // the packet, encoded as an engine.io message packet - let engine_packet = EnginePacket::new(EnginePacketId::Message, Bytes::from(&packet)); + let engine_packet = + EnginePacket::new(EnginePacketId::Message, self.packet_parser.encode(&packet)); self.engine_client.emit(engine_packet)?; if let Some(attachments) = packet.attachments { @@ -119,7 +120,7 @@ impl Socket { /// Handles new incoming engineio packets fn handle_engineio_packet(&self, packet: EnginePacket) -> Result { - let mut socket_packet = Packet::try_from(&packet.data)?; + let mut socket_packet = self.packet_parser.decode(&packet.data)?; // Only handle attachments if there are any if socket_packet.attachment_count > 0 {