Skip to content

Commit

Permalink
feat(socketio/packet): switch to Str type for ns path storage (#334)
Browse files Browse the repository at this point in the history
  • Loading branch information
Totodore authored Jun 13, 2024
1 parent af25013 commit dcfa088
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 31 deletions.
19 changes: 19 additions & 0 deletions engineioxide/src/str.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::borrow::Cow;

use bytes::Bytes;

/// A custom [`Bytes`] wrapper to efficiently store string packets
Expand Down Expand Up @@ -46,6 +48,23 @@ impl From<String> for Str {
}
}

impl From<Cow<'static, str>> for Str {
fn from(s: Cow<'static, str>) -> Self {
match s {
Cow::Borrowed(s) => Str::from(s),
Cow::Owned(s) => Str::from(s),
}
}
}
impl From<&Cow<'static, str>> for Str {
fn from(s: &Cow<'static, str>) -> Self {
match s {
Cow::Borrowed(s) => Str::from(*s),
Cow::Owned(s) => Str(Bytes::copy_from_slice(s.as_bytes())),
}
}
}

impl From<Str> for Bytes {
fn from(s: Str) -> Self {
s.0
Expand Down
8 changes: 4 additions & 4 deletions socketioxide/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ impl<A: Adapter> Client<A> {
fn sock_connect(
&self,
auth: Option<String>,
ns_path: &str,
ns_path: Str,
esocket: &Arc<engineioxide::Socket<SocketData<A>>>,
) -> Result<(), Error> {
#[cfg(feature = "tracing")]
tracing::debug!("auth: {:?}", auth);

if let Some(ns) = self.get_ns(ns_path) {
if let Some(ns) = self.get_ns(&ns_path) {
let esocket = esocket.clone();
tokio::spawn(async move {
if ns.connect(esocket.id, esocket.clone(), auth).await.is_ok() {
Expand Down Expand Up @@ -246,7 +246,7 @@ impl<A: Adapter> EngineIoHandler for Client<A> {
if protocol == ProtocolVersion::V4 {
#[cfg(feature = "tracing")]
tracing::debug!("connecting to default namespace for v4");
self.sock_connect(None, "/", &socket).unwrap();
self.sock_connect(None, Str::from("/"), &socket).unwrap();
}

if protocol == ProtocolVersion::V5 {
Expand Down Expand Up @@ -299,7 +299,7 @@ impl<A: Adapter> EngineIoHandler for Client<A> {

let res: Result<(), Error> = match packet.inner {
PacketData::Connect(auth) => self
.sock_connect(auth, &packet.ns, &socket)
.sock_connect(auth, packet.ns, &socket)
.map_err(Into::into),
PacketData::BinaryEvent(_, _, _) | PacketData::BinaryAck(_, _) => {
// Cache-in the socket data until all the binary payloads are received
Expand Down
2 changes: 1 addition & 1 deletion socketioxide/src/extract/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl<A: Adapter> AckSender<A> {
return Err(e.with_value(data).into());
}
};
let ns = self.socket.ns();
let ns = &self.socket.ns.path;
let data = serde_json::to_value(data)?;
let packet = if self.binary.is_empty() {
Packet::ack(ns, data, ack_id)
Expand Down
44 changes: 22 additions & 22 deletions socketioxide/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,70 +18,70 @@ pub struct Packet<'a> {
/// The packet data
pub inner: PacketData<'a>,
/// The namespace the packet belongs to
pub ns: Cow<'a, str>,
pub ns: Str,
}

impl<'a> Packet<'a> {
/// Send a connect packet with a default payload for v5 and no payload for v4
pub fn connect(
ns: &'a str,
ns: impl Into<Str>,
#[allow(unused_variables)] sid: Sid,
#[allow(unused_variables)] protocol: ProtocolVersion,
) -> Self {
#[cfg(not(feature = "v4"))]
{
Self::connect_v5(ns, sid)
Self::connect_v5(ns.into(), sid)
}

#[cfg(feature = "v4")]
{
match protocol {
ProtocolVersion::V4 => Self::connect_v4(ns),
ProtocolVersion::V5 => Self::connect_v5(ns, sid),
ProtocolVersion::V4 => Self::connect_v4(ns.into()),
ProtocolVersion::V5 => Self::connect_v5(ns.into(), sid),
}
}
}

/// Sends a connect packet without payload.
#[cfg(feature = "v4")]
fn connect_v4(ns: &'a str) -> Self {
fn connect_v4(ns: Str) -> Self {
Self {
inner: PacketData::Connect(None),
ns: Cow::Borrowed(ns),
ns,
}
}

/// Sends a connect packet with payload.
fn connect_v5(ns: &'a str, sid: Sid) -> Self {
fn connect_v5(ns: Str, sid: Sid) -> Self {
let val = serde_json::to_string(&ConnectPacket { sid }).unwrap();
Self {
inner: PacketData::Connect(Some(val)),
ns: Cow::Borrowed(ns),
ns,
}
}

/// Create a disconnect packet for the given namespace
pub fn disconnect(ns: &'a str) -> Self {
pub fn disconnect(ns: impl Into<Str>) -> Self {
Self {
inner: PacketData::Disconnect,
ns: Cow::Borrowed(ns),
ns: ns.into(),
}
}
}

impl<'a> Packet<'a> {
/// Create a connect error packet for the given namespace with a message
pub fn connect_error(ns: &'a str, message: &str) -> Self {
pub fn connect_error(ns: impl Into<Str>, message: &str) -> Self {
let message = serde_json::to_string(message).unwrap();
let packet = format!(r#"{{"message":{}}}"#, message);
Self {
inner: PacketData::ConnectError(packet),
ns: Cow::Borrowed(ns),
ns: ns.into(),
}
}

/// Create an event packet for the given namespace
pub fn event(ns: impl Into<Cow<'a, str>>, e: impl Into<Cow<'a, str>>, data: Value) -> Self {
pub fn event(ns: impl Into<Str>, e: impl Into<Cow<'a, str>>, data: Value) -> Self {
Self {
inner: PacketData::Event(e.into(), data, None),
ns: ns.into(),
Expand All @@ -90,7 +90,7 @@ impl<'a> Packet<'a> {

/// Create a binary event packet for the given namespace
pub fn bin_event(
ns: impl Into<Cow<'a, str>>,
ns: impl Into<Str>,
e: impl Into<Cow<'a, str>>,
data: Value,
bin: Vec<Bytes>,
Expand All @@ -105,20 +105,20 @@ impl<'a> Packet<'a> {
}

/// Create an ack packet for the given namespace
pub fn ack(ns: &'a str, data: Value, ack: i64) -> Self {
pub fn ack(ns: impl Into<Str>, data: Value, ack: i64) -> Self {
Self {
inner: PacketData::EventAck(data, ack),
ns: Cow::Borrowed(ns),
ns: ns.into(),
}
}

/// Create a binary ack packet for the given namespace
pub fn bin_ack(ns: &'a str, data: Value, bin: Vec<Bytes>, ack: i64) -> Self {
pub fn bin_ack(ns: impl Into<Str>, data: Value, bin: Vec<Bytes>, ack: i64) -> Self {
debug_assert!(!bin.is_empty());
let packet = BinaryPacket::outgoing(data, bin);
Self {
inner: PacketData::BinaryAck(packet, ack),
ns: Cow::Borrowed(ns),
ns: ns.into(),
}
}

Expand Down Expand Up @@ -466,19 +466,19 @@ impl<'a> TryFrom<Str> for Packet<'a> {
match chars.get(i) {
Some(b',') => {
i += 1;
break Cow::Owned(value[start_index..i - 1].to_string());
break value.slice(start_index..i - 1);
}
// It maybe possible depending on clients that ns does not end with a comma
// if it is the end of the packet
// e.g `1/custom`
None => {
break Cow::Owned(value[start_index..i].to_string());
break value.slice(start_index..i);
}
Some(_) => i += 1,
}
}
} else {
Cow::Borrowed("/")
Str::from("/")
};

let start_index = i;
Expand Down
5 changes: 3 additions & 2 deletions socketioxide/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ impl<A: Adapter> Socket<A> {
}
};

let ns = self.ns();
let ns = &self.ns.path;
let data = serde_json::to_value(data)?;
permit.send(Packet::event(ns, event.into(), data));
Ok(())
Expand Down Expand Up @@ -392,8 +392,9 @@ impl<A: Adapter> Socket<A> {
return Err(e.with_value(data).into());
}
};
let ns = &self.ns.path;
let data = serde_json::to_value(data)?;
let packet = Packet::event(self.ns(), event.into(), data);
let packet = Packet::event(ns, event.into(), data);
let rx = self.send_with_ack_permit(packet, permit);
let stream = AckInnerStream::send(rx, self.get_io().config().ack_timeout, self.id);
Ok(AckStream::<V>::from(stream))
Expand Down
6 changes: 5 additions & 1 deletion socketioxide/tests/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ use socketioxide::{
};
use tokio::sync::mpsc;

fn create_msg(ns: &str, event: &str, data: impl Into<serde_json::Value>) -> engineioxide::Packet {
fn create_msg(
ns: &'static str,
event: &str,
data: impl Into<serde_json::Value>,
) -> engineioxide::Packet {
let packet: String = Packet::event(ns, event, data.into()).into();
Message(packet.into())
}
Expand Down
2 changes: 1 addition & 1 deletion socketioxide/tests/extractors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async fn timeout_rcv_err<T: std::fmt::Debug>(srx: &mut tokio::sync::mpsc::Receiv
.unwrap_err();
}

fn create_msg(ns: &str, event: &str, data: impl Into<serde_json::Value>) -> EioPacket {
fn create_msg(ns: &'static str, event: &str, data: impl Into<serde_json::Value>) -> EioPacket {
let packet: String = Packet::event(ns, event, data.into()).into();
EioPacket::Message(packet.into())
}
Expand Down

0 comments on commit dcfa088

Please sign in to comment.