diff --git a/Cargo.lock b/Cargo.lock index c2da666c..714cd82c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -228,7 +228,7 @@ dependencies = [ "criterion-plot", "futures", "is-terminal", - "itertools", + "itertools 0.10.5", "num-traits", "once_cell", "oorandom", @@ -250,7 +250,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" dependencies = [ "cast", - "itertools", + "itertools 0.10.5", ] [[package]] @@ -707,6 +707,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.6" @@ -1252,7 +1261,7 @@ dependencies = [ "futures", "http", "http-body", - "itertools", + "itertools 0.11.0", "serde", "serde_json", "thiserror", diff --git a/examples/src/chat/handlers.rs b/examples/src/chat/handlers.rs index 42fadbd2..4d9b0b40 100644 --- a/examples/src/chat/handlers.rs +++ b/examples/src/chat/handlers.rs @@ -18,7 +18,7 @@ pub async fn handler(socket: Arc>) { info!("Nickname: {:?}", data.nickname); socket.extensions.insert(data.nickname); socket.emit("message", "Welcome to the chat!").ok(); - socket.join("default"); + socket.join("default").unwrap(); } else { info!("No nickname provided, disconnecting..."); socket.disconnect().ok(); @@ -30,8 +30,8 @@ pub async fn handler(socket: Arc>) { |socket, (room, message): (String, String), _, _| async move { let Nickname(ref nickname) = *socket.extensions.get().unwrap(); info!("transfering message from {nickname} to {room}: {message}"); - info!("Sockets in room: {:?}", socket.local().sockets()); - if let Some(dest) = socket.to("default").sockets().iter().find(|s| { + info!("Sockets in room: {:?}", socket.local().sockets().unwrap()); + if let Some(dest) = socket.to("default").sockets().unwrap().iter().find(|s| { s.extensions .get::() .map(|n| n.0 == room) @@ -51,12 +51,12 @@ pub async fn handler(socket: Arc>) { socket.on("join", |socket, room: String, _, _| async move { info!("Joining room {}", room); - socket.join(room); + socket.join(room).unwrap(); }); socket.on("leave", |socket, room: String, _, _| async move { info!("Leaving room {}", room); - socket.leave(room); + socket.leave(room).unwrap(); }); socket.on("list", |socket, room: Option, _, _| async move { @@ -65,6 +65,7 @@ pub async fn handler(socket: Arc>) { let sockets = socket .within(room) .sockets() + .unwrap() .iter() .filter_map(|s| s.extensions.get::()) .fold("".to_string(), |a, b| a + &b.0 + ", ") @@ -72,7 +73,7 @@ pub async fn handler(socket: Arc>) { .to_string(); socket.emit("message", sockets).ok(); } else { - let rooms = socket.rooms(); + let rooms = socket.rooms().unwrap(); info!("Listing rooms: {:?}", &rooms); socket.emit("message", rooms).ok(); } diff --git a/socketioxide/src/adapter.rs b/socketioxide/src/adapter.rs index fc8eba4c..e2d21be0 100644 --- a/socketioxide/src/adapter.rs +++ b/socketioxide/src/adapter.rs @@ -5,6 +5,7 @@ use std::{ collections::{HashMap, HashSet}, + convert::Infallible, sync::{Arc, RwLock, Weak}, time::Duration, }; @@ -18,15 +19,12 @@ use itertools::Itertools; use serde::de::DeserializeOwned; use crate::{ - errors::{ - AckError, - BroadcastError, - }, + errors::{AckError, BroadcastError, AdapterError}, handler::AckResponse, ns::Namespace, operators::RoomParam, packet::Packet, - socket::Socket + socket::Socket, }; /// A room identifier @@ -68,53 +66,64 @@ impl BroadcastOptions { //TODO: Make an AsyncAdapter trait pub trait Adapter: std::fmt::Debug + Send + Sync + 'static { + type Error: std::error::Error + Into + Send + 'static; + /// Create a new adapter and give the namespace ref to retrieve sockets. fn new(ns: Weak>) -> Self where Self: Sized; /// Initialize the adapter. - fn init(&self); + fn init(&self) -> Result<(), Self::Error>; /// Close the adapter. - fn close(&self); + fn close(&self) -> Result<(), Self::Error>; /// Return the number of servers. - fn server_count(&self) -> u16; + fn server_count(&self) -> Result; /// Add the socket to all the rooms. - fn add_all(&self, sid: Sid, rooms: impl RoomParam); + fn add_all(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Self::Error>; /// Remove the socket from the rooms. - fn del(&self, sid: Sid, rooms: impl RoomParam); + fn del(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Self::Error>; /// Remove the socket from all the rooms. - fn del_all(&self, sid: Sid); + fn del_all(&self, sid: Sid) -> Result<(), Self::Error>; /// Broadcast the packet to the sockets that match the [`BroadcastOptions`]. - fn broadcast(&self, packet: Packet, opts: BroadcastOptions) -> Result<(), BroadcastError>; + fn broadcast( + &self, + packet: Packet, + opts: BroadcastOptions, + ) -> Result<(), BroadcastError>; /// Broadcast the packet to the sockets that match the [`BroadcastOptions`] and return a stream of ack responses. fn broadcast_with_ack( &self, packet: Packet, opts: BroadcastOptions, - ) -> BoxStream<'static, Result, AckError>>; + ) -> Result, AckError>>, BroadcastError>; /// Return the sockets ids that match the [`BroadcastOptions`]. - fn sockets(&self, rooms: impl RoomParam) -> Vec; + fn sockets(&self, rooms: impl RoomParam) -> Result, Self::Error>; /// Return the rooms of the socket. - fn socket_rooms(&self, sid: Sid) -> Vec; + fn socket_rooms(&self, sid: Sid) -> Result, Self::Error>; /// Return the sockets that match the [`BroadcastOptions`]. - fn fetch_sockets(&self, opts: BroadcastOptions) -> Vec>> + fn fetch_sockets(&self, opts: BroadcastOptions) -> Result>>, Self::Error> where Self: Sized; /// Add the sockets that match the [`BroadcastOptions`] to the rooms. - fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam); + fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) + -> Result<(), Self::Error>; /// Remove the sockets that match the [`BroadcastOptions`] from the rooms. - fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam); + fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) + -> Result<(), Self::Error>; /// Disconnect the sockets that match the [`BroadcastOptions`]. - fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError>; + fn disconnect_socket( + &self, + opts: BroadcastOptions, + ) -> Result<(), BroadcastError>; //TODO: implement // fn server_side_emit(&self, packet: Packet, opts: BroadcastOptions) -> Result; @@ -129,7 +138,15 @@ pub struct LocalAdapter { ns: Weak>, } +impl From for AdapterError { + fn from(_: Infallible) -> AdapterError { + unreachable!() + } +} + impl Adapter for LocalAdapter { + type Error = Infallible; + fn new(ns: Weak>) -> Self { Self { rooms: HashMap::new().into(), @@ -137,15 +154,15 @@ impl Adapter for LocalAdapter { } } - fn init(&self) {} + fn init(&self) -> Result<(), Infallible> { Ok(()) } - fn close(&self) {} + fn close(&self) -> Result<(), Infallible> { Ok(()) } - fn server_count(&self) -> u16 { - 1 + fn server_count(&self) -> Result { + Ok(1) } - fn add_all(&self, sid: Sid, rooms: impl RoomParam) { + fn add_all(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Infallible> { let mut rooms_map = self.rooms.write().unwrap(); for room in rooms.into_room_iter() { rooms_map @@ -153,22 +170,25 @@ impl Adapter for LocalAdapter { .or_insert_with(HashSet::new) .insert(sid); } + Ok(()) } - fn del(&self, sid: Sid, rooms: impl RoomParam) { + fn del(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Infallible> { let mut rooms_map = self.rooms.write().unwrap(); for room in rooms.into_room_iter() { if let Some(room) = rooms_map.get_mut(&room) { room.remove(&sid); } } + Ok(()) } - fn del_all(&self, sid: Sid) { + fn del_all(&self, sid: Sid) -> Result<(), Infallible> { let mut rooms_map = self.rooms.write().unwrap(); for room in rooms_map.values_mut() { room.remove(&sid); } + Ok(()) } fn broadcast(&self, packet: Packet, opts: BroadcastOptions) -> Result<(), BroadcastError> { @@ -190,7 +210,7 @@ impl Adapter for LocalAdapter { &self, packet: Packet, opts: BroadcastOptions, - ) -> BoxStream<'static, Result, AckError>> { + ) -> Result, AckError>>, BroadcastError> { let duration = opts.flags.iter().find_map(|flag| match flag { BroadcastFlags::Timeout(duration) => Some(*duration), _ => None, @@ -206,44 +226,46 @@ impl Adapter for LocalAdapter { let packet = packet.clone(); async move { socket.clone().send_with_ack(packet, duration).await } }); - stream::iter(ack_futs).buffer_unordered(count).boxed() + Ok(stream::iter(ack_futs).buffer_unordered(count).boxed()) } - fn sockets(&self, rooms: impl RoomParam) -> Vec { + fn sockets(&self, rooms: impl RoomParam) -> Result, Infallible> { let mut opts = BroadcastOptions::new(0i64.into()); opts.rooms.extend(rooms.into_room_iter()); - self.apply_opts(opts) + Ok(self.apply_opts(opts) .into_iter() .map(|socket| socket.sid) - .collect() + .collect()) } //TODO: make this operation O(1) - fn socket_rooms(&self, sid: Sid) -> Vec { + fn socket_rooms(&self, sid: Sid) -> Result, Infallible> { let rooms_map = self.rooms.read().unwrap(); - rooms_map + Ok(rooms_map .iter() .filter(|(_, sockets)| sockets.contains(&sid)) .map(|(room, _)| room.clone()) - .collect() + .collect()) } - fn fetch_sockets(&self, opts: BroadcastOptions) -> Vec>> { - self.apply_opts(opts) + fn fetch_sockets(&self, opts: BroadcastOptions) -> Result>>, Infallible> { + Ok(self.apply_opts(opts)) } - fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) { + fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) -> Result<(), Infallible> { let rooms: Vec = rooms.into_room_iter().collect(); for socket in self.apply_opts(opts) { - self.add_all(socket.sid, rooms.clone()); + self.add_all(socket.sid, rooms.clone()).unwrap(); } + Ok(()) } - fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) { + fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) -> Result<(), Infallible> { let rooms: Vec = rooms.into_room_iter().collect(); for socket in self.apply_opts(opts) { - self.del(socket.sid, rooms.clone()); + self.del(socket.sid, rooms.clone()).unwrap(); } + Ok(()) } fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> { @@ -313,7 +335,7 @@ mod test { async fn test_server_count() { let ns = Namespace::new_dummy([]); let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - assert_eq!(adapter.server_count(), 1); + assert_eq!(adapter.server_count().unwrap(), 1); } #[tokio::test] @@ -321,7 +343,7 @@ mod test { let socket: Sid = 1i64.into(); let ns = Namespace::new_dummy([socket]); let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]); + adapter.add_all(socket, ["room1", "room2"]).unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); assert_eq!(rooms_map.get("room1").unwrap().len(), 1); @@ -333,8 +355,8 @@ mod test { let socket: Sid = 1i64.into(); let ns = Namespace::new_dummy([socket]); let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]); - adapter.del(socket, "room1"); + adapter.add_all(socket, ["room1", "room2"]).unwrap(); + adapter.del(socket, "room1").unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); assert_eq!(rooms_map.get("room1").unwrap().len(), 0); @@ -346,8 +368,8 @@ mod test { let socket: Sid = 1i64.into(); let ns = Namespace::new_dummy([socket]); let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]); - adapter.del_all(socket); + adapter.add_all(socket, ["room1", "room2"]).unwrap(); + adapter.del_all(socket).unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); assert_eq!(rooms_map.get("room1").unwrap().len(), 0); @@ -358,13 +380,13 @@ mod test { async fn test_socket_room() { let ns = Namespace::new_dummy([1i64, 2, 3].map(Into::into)); let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(1i64.into(), ["room1", "room2"]); - adapter.add_all(2i64.into(), ["room1"]); - adapter.add_all(3i64.into(), ["room2"]); - assert!(adapter.socket_rooms(1i64.into()).contains(&"room1".into())); - assert!(adapter.socket_rooms(1i64.into()).contains(&"room2".into())); - assert_eq!(adapter.socket_rooms(2i64.into()), ["room1"]); - assert_eq!(adapter.socket_rooms(3i64.into()), ["room2"]); + adapter.add_all(1i64.into(), ["room1", "room2"]).unwrap(); + adapter.add_all(2i64.into(), ["room1"]).unwrap(); + adapter.add_all(3i64.into(), ["room2"]).unwrap(); + assert!(adapter.socket_rooms(1i64.into()).unwrap().contains(&"room1".into())); + assert!(adapter.socket_rooms(1i64.into()).unwrap().contains(&"room2".into())); + assert_eq!(adapter.socket_rooms(2i64.into()).unwrap(), ["room1"]); + assert_eq!(adapter.socket_rooms(3i64.into()).unwrap(), ["room2"]); } #[tokio::test] @@ -372,11 +394,11 @@ mod test { let socket: Sid = 0i64.into(); let ns = Namespace::new_dummy([socket]); let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1"]); + adapter.add_all(socket, ["room1"]).unwrap(); let mut opts = BroadcastOptions::new(socket); opts.rooms = vec!["room1".to_string()]; - adapter.add_sockets(opts, "room2"); + adapter.add_sockets(opts, "room2").unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); @@ -389,11 +411,11 @@ mod test { let socket: Sid = 0i64.into(); let ns = Namespace::new_dummy([socket]); let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1"]); + adapter.add_all(socket, ["room1"]).unwrap(); let mut opts = BroadcastOptions::new(socket); opts.rooms = vec!["room1".to_string()]; - adapter.add_sockets(opts, "room2"); + adapter.add_sockets(opts, "room2").unwrap(); { let rooms_map = adapter.rooms.read().unwrap(); @@ -405,7 +427,7 @@ mod test { let mut opts = BroadcastOptions::new(socket); opts.rooms = vec!["room1".to_string()]; - adapter.del_sockets(opts, "room2"); + adapter.del_sockets(opts, "room2").unwrap(); { let rooms_map = adapter.rooms.read().unwrap(); @@ -423,21 +445,21 @@ mod test { let socket2: Sid = 2i64.into(); let ns = Namespace::new_dummy([socket0, socket1, socket2]); let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket0, ["room1", "room2"]); - adapter.add_all(socket1, ["room1", "room3"]); - adapter.add_all(socket2, ["room2", "room3"]); + adapter.add_all(socket0, ["room1", "room2"]).unwrap(); + adapter.add_all(socket1, ["room1", "room3"]).unwrap(); + adapter.add_all(socket2, ["room2", "room3"]).unwrap(); - let sockets = adapter.sockets("room1"); + let sockets = adapter.sockets("room1").unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket0)); assert!(sockets.contains(&socket1)); - let sockets = adapter.sockets("room2"); + let sockets = adapter.sockets("room2").unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket0)); assert!(sockets.contains(&socket2)); - let sockets = adapter.sockets("room3"); + let sockets = adapter.sockets("room3").unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket1)); assert!(sockets.contains(&socket2)); @@ -450,9 +472,9 @@ mod test { let socket2: Sid = 2i64.into(); let ns = Namespace::new_dummy([socket0, socket1, socket2]); let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket0, ["room1", "room2", "room4"]); - adapter.add_all(socket1, ["room1", "room3", "room5"]); - adapter.add_all(socket2, ["room2", "room3", "room6"]); + adapter.add_all(socket0, ["room1", "room2", "room4"]).unwrap(); + adapter.add_all(socket1, ["room1", "room3", "room5"]).unwrap(); + adapter.add_all(socket2, ["room2", "room3", "room6"]).unwrap(); let mut opts = BroadcastOptions::new(socket0); opts.rooms = vec!["room5".to_string()]; @@ -465,7 +487,7 @@ mod test { ), } - let sockets = adapter.sockets("room2"); + let sockets = adapter.sockets("room2").unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket2)); assert!(sockets.contains(&socket0)); @@ -478,33 +500,33 @@ mod test { let ns = Namespace::new_dummy([socket0, socket1, socket2]); let adapter = LocalAdapter::new(Arc::downgrade(&ns)); // Add socket 0 to room1 and room2 - adapter.add_all(socket0, ["room1", "room2"]); + adapter.add_all(socket0, ["room1", "room2"]).unwrap(); // Add socket 1 to room1 and room3 - adapter.add_all(socket1, ["room1", "room3"]); + adapter.add_all(socket1, ["room1", "room3"]).unwrap(); // Add socket 2 to room2 and room3 - adapter.add_all(socket2, ["room1", "room2", "room3"]); + adapter.add_all(socket2, ["room1", "room2", "room3"]).unwrap(); // socket 2 is the sender let mut opts = BroadcastOptions::new(socket2); opts.rooms = vec!["room1".to_string()]; opts.except = vec!["room2".to_string()]; - let sockets = adapter.fetch_sockets(opts); + let sockets = adapter.fetch_sockets(opts).unwrap(); assert_eq!(sockets.len(), 1); assert_eq!(sockets[0].sid, socket1); let mut opts = BroadcastOptions::new(socket2); opts.flags.insert(BroadcastFlags::Broadcast); opts.except = vec!["room2".to_string()]; - let sockets = adapter.fetch_sockets(opts); + let sockets = adapter.fetch_sockets(opts).unwrap(); assert_eq!(sockets.len(), 1); let opts = BroadcastOptions::new(socket2); - let sockets = adapter.fetch_sockets(opts); + let sockets = adapter.fetch_sockets(opts).unwrap(); assert_eq!(sockets.len(), 1); assert_eq!(sockets[0].sid, socket2); let opts = BroadcastOptions::new(10000i64.into()); - let sockets = adapter.fetch_sockets(opts); + let sockets = adapter.fetch_sockets(opts).unwrap(); assert_eq!(sockets.len(), 0); } } diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index 6e05af61..0f6503bd 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -121,7 +121,9 @@ impl EngineIoHandler for Client { fn on_disconnect(&self, socket: &EIoSocket) { debug!("eio socket disconnect {}", socket.sid); self.ns.values().for_each(|ns| { - ns.disconnect(socket.sid).ok(); + if let Err(e) = ns.remove_socket(socket.sid) { + error!("Adapter error when disconnecting {}: {}, in a multiple server scenario it could leads to desyncronisation issues", socket.sid, e); + } }); } @@ -185,4 +187,4 @@ impl Clone for Client { ns: self.ns.clone(), } } -} \ No newline at end of file +} diff --git a/socketioxide/src/errors.rs b/socketioxide/src/errors.rs index f4af940d..1804a95c 100644 --- a/socketioxide/src/errors.rs +++ b/socketioxide/src/errors.rs @@ -1,6 +1,6 @@ use crate::retryer::Retryer; use engineioxide::sid_generator::Sid; -use std::fmt::Debug; +use std::fmt::{Debug, Display}; use tokio::sync::oneshot; /// Error type for socketio @@ -24,6 +24,9 @@ pub enum Error { /// An engineio error #[error("engineio error: {0}")] EngineIoError(#[from] engineioxide::errors::Error), + + #[error("adapter error: {0}")] + Adapter(#[from] AdapterError), } /// Error type for ack responses @@ -59,6 +62,9 @@ pub enum BroadcastError { /// An error occurred while serializing the JSON packet. #[error("Error serializing JSON packet: {0:?}")] Serialize(#[from] serde_json::Error), + + #[error("Adapter error: {0}")] + Adapter(#[from] AdapterError), } impl From> for BroadcastError { @@ -85,6 +91,9 @@ pub enum SendError { /// An error occurred during the retry process in the `Retryer`. #[error("Send error: {0:?}")] RetryerError(#[from] RetryerError), + + #[error("Adapter error: {0}")] + AdapterError(#[from] AdapterError), } /// Error type for the `Retryer` struct indicating various failure scenarios during the retry process. @@ -97,3 +106,12 @@ pub enum RetryerError { #[error("Sent to a full socket channel")] Remaining(Retryer), } + +/// Error type for the [`Adapter`] trait. +#[derive(Debug, thiserror::Error)] +pub struct AdapterError(#[from] pub Box); +impl Display for AdapterError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(&self.0, f) + } +} diff --git a/socketioxide/src/ns.rs b/socketioxide/src/ns.rs index 3f46686d..84d501a7 100644 --- a/socketioxide/src/ns.rs +++ b/socketioxide/src/ns.rs @@ -3,7 +3,7 @@ use std::{ sync::{Arc, RwLock}, }; -use crate::errors::SendError; +use crate::errors::{AdapterError, SendError}; use crate::{ adapter::{Adapter, LocalAdapter}, errors::Error, @@ -69,14 +69,18 @@ impl Namespace { pub fn disconnect(&self, sid: Sid) -> Result<(), SendError> { if let Some(socket) = self.sockets.write().unwrap().remove(&sid) { - self.adapter.del_all(sid); + self.adapter + .del_all(sid) + .map_err(|err| AdapterError(Box::new(err)))?; socket.send(Packet::disconnect(self.path.clone()))?; } Ok(()) } - fn remove_socket(&self, sid: Sid) { + pub fn remove_socket(&self, sid: Sid) -> Result<(), AdapterError> { self.sockets.write().unwrap().remove(&sid); - self.adapter.del_all(sid); + self.adapter + .del_all(sid) + .map_err(|err| AdapterError(Box::new(err))) } pub fn has(&self, sid: Sid) -> bool { @@ -90,10 +94,9 @@ impl Namespace { pub fn recv(&self, sid: Sid, packet: PacketData) -> Result<(), Error> { match packet { - PacketData::Disconnect => { - self.remove_socket(sid); - Ok(()) - } + PacketData::Disconnect => self + .remove_socket(sid) + .map_err(|err| AdapterError(Box::new(err)).into()), PacketData::Connect(_) => unreachable!("connect packets should be handled before"), PacketData::ConnectError(_) => Ok(()), packet => self.socket_recv(sid, packet), diff --git a/socketioxide/src/operators.rs b/socketioxide/src/operators.rs index 129ed24a..0b781eb4 100644 --- a/socketioxide/src/operators.rs +++ b/socketioxide/src/operators.rs @@ -8,7 +8,7 @@ use serde::{de::DeserializeOwned, Serialize}; use crate::errors::BroadcastError; use crate::{ adapter::{Adapter, BroadcastFlags, BroadcastOptions, Room}, - errors::{AckError, Error}, + errors::AckError, handler::AckResponse, ns::Namespace, packet::Packet, @@ -263,9 +263,9 @@ impl Operators { mut self, event: impl Into, data: impl serde::Serialize, - ) -> Result, AckError>>, Error> { + ) -> Result, AckError>>, BroadcastError> { let packet = self.get_packet(event, data)?; - Ok(self.ns.adapter.broadcast_with_ack(packet, self.opts)) + self.ns.adapter.broadcast_with_ack(packet, self.opts) } /// Get all sockets selected with the previous operators. @@ -278,13 +278,13 @@ impl Operators { /// Namespace::builder().add("/", |socket| async move { /// socket.on("test", |socket, _: (), _, _| async move { /// // Find an extension data in each sockets in the room1 and room3 rooms, except for the room2 - /// let sockets = socket.within("room1").within("room3").except("room2").sockets(); + /// let sockets = socket.within("room1").within("room3").except("room2").sockets().unwrap(); /// for socket in sockets { /// println!("Socket custom string: {:?}", socket.extensions.get::()); /// } /// }); /// }); - pub fn sockets(self) -> Vec>> { + pub fn sockets(self) -> Result>>, A::Error> { self.ns.adapter.fetch_sockets(self.opts) } diff --git a/socketioxide/src/socket.rs b/socketioxide/src/socket.rs index 838a3fb5..17949e09 100644 --- a/socketioxide/src/socket.rs +++ b/socketioxide/src/socket.rs @@ -179,22 +179,22 @@ impl Socket { // Room actions /// Join the given rooms. - pub fn join(&self, rooms: impl RoomParam) { - self.ns.adapter.add_all(self.sid, rooms); + pub fn join(&self, rooms: impl RoomParam) -> Result<(), A::Error> { + self.ns.adapter.add_all(self.sid, rooms) } /// Leave the given rooms. - pub fn leave(&self, rooms: impl RoomParam) { - self.ns.adapter.del(self.sid, rooms); + pub fn leave(&self, rooms: impl RoomParam) -> Result<(), A::Error> { + self.ns.adapter.del(self.sid, rooms) } /// Leave all rooms where the socket is connected. - pub fn leave_all(&self) { - self.ns.adapter.del_all(self.sid); + pub fn leave_all(&self) -> Result<(), A::Error> { + self.ns.adapter.del_all(self.sid) } /// Get all rooms where the socket is connected. - pub fn rooms(&self) -> Vec { + pub fn rooms(&self) -> Result, A::Error> { self.ns.adapter.socket_rooms(self.sid) }