diff --git a/examples/angular-todomvc/src/main.rs b/examples/angular-todomvc/src/main.rs index a5c87fa8..0dcf78c2 100644 --- a/examples/angular-todomvc/src/main.rs +++ b/examples/angular-todomvc/src/main.rs @@ -5,13 +5,14 @@ use socketioxide::{ extract::{Data, SocketRef, State}, SocketIo, }; +use std::sync::Arc; use tower::ServiceBuilder; use tower_http::{cors::CorsLayer, services::ServeDir}; use tracing::info; use tracing_subscriber::FmtSubscriber; -#[derive(Default)] -struct Todos(pub Mutex>); +#[derive(Default, Clone)] +struct Todos(Arc>>); #[derive(Debug, Clone, Serialize, Deserialize)] struct Todo { diff --git a/examples/basic-crud-application/src/handlers/todo.rs b/examples/basic-crud-application/src/handlers/todo.rs index 56fad816..c00b9adb 100644 --- a/examples/basic-crud-application/src/handlers/todo.rs +++ b/examples/basic-crud-application/src/handlers/todo.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::RwLock}; +use std::{collections::HashMap, sync::{RwLock, Arc}}; use serde::{Deserialize, Serialize}; use socketioxide::extract::{AckSender, Data, SocketRef, State}; @@ -21,8 +21,8 @@ pub struct PartialTodo { title: String, } -#[derive(Default)] -pub struct Todos(RwLock>); +#[derive(Clone, Default)] +pub struct Todos(Arc>>); impl Todos { fn insert(&self, id: Uuid, todo: Todo) { self.0.write().unwrap().insert(id, todo); diff --git a/examples/chat/src/main.rs b/examples/chat/src/main.rs index c3571ab0..6417d430 100644 --- a/examples/chat/src/main.rs +++ b/examples/chat/src/main.rs @@ -9,6 +9,7 @@ use tower::ServiceBuilder; use tower_http::{cors::CorsLayer, services::ServeDir}; use tracing::info; use tracing_subscriber::FmtSubscriber; +use std::sync::Arc; #[derive(Deserialize, Serialize, Debug, Clone)] #[serde(transparent)] @@ -34,11 +35,11 @@ enum Res { username: Username, }, } - -struct UserCnt(AtomicUsize); +#[derive(Clone)] +struct UserCnt(Arc); impl UserCnt { fn new() -> Self { - Self(AtomicUsize::new(0)) + Self(Arc::new(AtomicUsize::new(0))) } fn add_user(&self) -> usize { self.0.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + 1 diff --git a/examples/loco-rooms-chat/Cargo.toml b/examples/loco-rooms-chat/Cargo.toml index c59da93d..322f9a64 100644 --- a/examples/loco-rooms-chat/Cargo.toml +++ b/examples/loco-rooms-chat/Cargo.toml @@ -6,7 +6,7 @@ rust-version = "1.70" # required by loco [dependencies] -loco-rs = { version = "0.3.1", default-features = false, features = [ +loco-rs = { version = "0.5.0", default-features = false, features = [ "cli", "channels", ] } diff --git a/examples/loco-rooms-chat/src/channels/state.rs b/examples/loco-rooms-chat/src/channels/state.rs index 1db8231c..8d12c84d 100644 --- a/examples/loco-rooms-chat/src/channels/state.rs +++ b/examples/loco-rooms-chat/src/channels/state.rs @@ -1,5 +1,6 @@ use std::collections::{HashMap, VecDeque}; use tokio::sync::RwLock; +use std::sync::Arc; #[derive(serde::Serialize, Clone, Debug)] pub struct Message { @@ -10,9 +11,9 @@ pub struct Message { pub type RoomStore = HashMap>; -#[derive(Default)] +#[derive(Default, Clone)] pub struct MessageStore { - pub messages: RwLock, + pub messages: Arc>, } impl MessageStore { diff --git a/examples/private-messaging/src/store.rs b/examples/private-messaging/src/store.rs index d1553d8d..b0d194b3 100644 --- a/examples/private-messaging/src/store.rs +++ b/examples/private-messaging/src/store.rs @@ -47,8 +47,8 @@ pub struct Message { pub content: String, } -#[derive(Default)] -pub struct Sessions(RwLock>>); +#[derive(Clone, Default)] +pub struct Sessions(Arc>>>); impl Sessions { pub fn get_all_other_sessions(&self, user_id: Uuid) -> Vec> { @@ -69,8 +69,8 @@ impl Sessions { self.0.write().unwrap().insert(session.session_id, session); } } -#[derive(Default)] -pub struct Messages(RwLock>); +#[derive(Clone, Default)] +pub struct Messages(Arc>>); impl Messages { pub fn add(&self, message: Message) { diff --git a/examples/react-rooms-chat/Cargo.toml b/examples/react-rooms-chat/Cargo.toml index c2dfb149..3ec65399 100644 --- a/examples/react-rooms-chat/Cargo.toml +++ b/examples/react-rooms-chat/Cargo.toml @@ -6,13 +6,13 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -socketioxide = { version = "0.9.0", features = ["state"] } +socketioxide = { path = "../../socketioxide", features = ["state"] } tokio = { version = "1", features = ["full"] } tracing = "0.1" tracing-subscriber = "0.3" axum = "0.7.2" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -tower-http = {version = "0.5.0", features = ["cors"]} +tower-http = { version = "0.5.0", features = ["cors"] } tower = "0.4" chrono = { version = "0.4", features = ["serde"] } diff --git a/examples/react-rooms-chat/src/state.rs b/examples/react-rooms-chat/src/state.rs index 1db8231c..4326d1d0 100644 --- a/examples/react-rooms-chat/src/state.rs +++ b/examples/react-rooms-chat/src/state.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, VecDeque}; +use std::{collections::{HashMap, VecDeque}, sync::Arc}; use tokio::sync::RwLock; #[derive(serde::Serialize, Clone, Debug)] @@ -10,9 +10,9 @@ pub struct Message { pub type RoomStore = HashMap>; -#[derive(Default)] +#[derive(Clone, Default)] pub struct MessageStore { - pub messages: RwLock, + pub messages: Arc>, } impl MessageStore { diff --git a/socketioxide/src/adapter.rs b/socketioxide/src/adapter.rs index 0e875b19..58850c7d 100644 --- a/socketioxide/src/adapter.rs +++ b/socketioxide/src/adapter.rs @@ -1,4 +1,4 @@ -//! Adapters are responsible for managing the state of the server. +//! Adapters are responsible for managing the internal state of the server (rooms, sockets, etc...). //! When a socket joins or leaves a room, the adapter is responsible for updating the state. //! The default adapter is the [`LocalAdapter`], which stores the state in memory. //! Other adapters can be made to share the state between multiple servers. diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index c78bd098..4a3e99d2 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -22,20 +22,26 @@ use crate::{ }; use crate::{ProtocolVersion, SocketIo}; -#[derive(Debug)] pub struct Client { pub(crate) config: Arc, ns: RwLock, Arc>>>, + #[cfg(feature = "state")] + pub(crate) state: state::TypeMap![Send + Sync], } impl Client { - pub fn new(config: Arc) -> Self { + pub fn new( + config: Arc, + #[cfg(feature = "state")] mut state: state::TypeMap![Send + Sync], + ) -> Self { #[cfg(feature = "state")] - crate::state::freeze_state(); + state.freeze(); Self { config, ns: RwLock::new(HashMap::new()), + #[cfg(feature = "state")] + state, } } @@ -346,6 +352,15 @@ impl EngineIoHandler for Client { } } } +impl std::fmt::Debug for Client { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut f = f.debug_struct("Client"); + f.field("config", &self.config).field("ns", &self.ns); + #[cfg(feature = "state")] + let f = f.field("state", &self.state); + f.finish() + } +} /// Utility that applies an incoming binary payload to a partial binary packet /// waiting to be filled with all the payloads @@ -382,7 +397,11 @@ mod test { connect_timeout: CONNECT_TIMEOUT, ..Default::default() }; - let client = Client::::new(std::sync::Arc::new(config)); + let client = Client::::new( + std::sync::Arc::new(config), + #[cfg(feature = "state")] + Default::default(), + ); client.add_ns("/".into(), || {}); Arc::new(client) } diff --git a/socketioxide/src/extract/mod.rs b/socketioxide/src/extract/mod.rs index 9b84fd99..b6f7196d 100644 --- a/socketioxide/src/extract/mod.rs +++ b/socketioxide/src/extract/mod.rs @@ -16,7 +16,7 @@ //! * [`ProtocolVersion`](crate::ProtocolVersion): extracts the protocol version //! * [`TransportType`](crate::TransportType): extracts the transport type //! * [`DisconnectReason`](crate::socket::DisconnectReason): extracts the reason of the disconnection -//! * [`State`]: extracts a reference to a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). +//! * [`State`]: extracts a [`Clone`] of a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). //! * [`Extension`]: extracts an extension of the given type stored on the called socket by cloning it. //! * [`MaybeExtension`]: extracts an extension of the given type if it exists or [`None`] otherwise //! * [`HttpExtension`]: extracts an http extension of the given type coming from the request. diff --git a/socketioxide/src/extract/state.rs b/socketioxide/src/extract/state.rs index 4c7f67ad..50886aec 100644 --- a/socketioxide/src/extract/state.rs +++ b/socketioxide/src/extract/state.rs @@ -1,14 +1,12 @@ use bytes::Bytes; -use crate::state::get_state; -use std::ops::Deref; use std::sync::Arc; use crate::adapter::Adapter; use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts}; use crate::socket::{DisconnectReason, Socket}; -/// An Extractor that contains a reference to a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). +/// An Extractor that contains a [`Clone`] of a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). /// It implements [`std::ops::Deref`] to access the inner type so you can use it as a normal reference. /// /// The specified state type must be the same as the one set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). @@ -20,11 +18,10 @@ use crate::socket::{DisconnectReason, Socket}; /// ``` /// # use socketioxide::{SocketIo, extract::{SocketRef, State}}; /// # use serde::{Serialize, Deserialize}; -/// # use std::sync::atomic::AtomicUsize; -/// # use std::sync::atomic::Ordering; -/// #[derive(Default)] +/// # use std::sync::{Arc, atomic::{Ordering, AtomicUsize}}; +/// #[derive(Default, Clone)] /// struct MyAppData { -/// user_cnt: AtomicUsize, +/// user_cnt: Arc, /// } /// impl MyAppData { /// fn add_user(&self) { @@ -39,7 +36,7 @@ use crate::socket::{DisconnectReason, Socket}; /// state.add_user(); /// println!("User count: {}", state.user_cnt.load(Ordering::SeqCst)); /// }); -pub struct State(pub &'static T); +pub struct State(pub T); /// It was impossible to find the given state and therefore the handler won't be called. pub struct StateNotFound(std::marker::PhantomData); @@ -48,7 +45,7 @@ impl std::fmt::Display for StateNotFound { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "State of type {} not found, maybe you forgot to insert it in the extensions map?", + "State of type {} not found, maybe you forgot to insert it in the state map?", std::any::type_name::() ) } @@ -60,46 +57,43 @@ impl std::fmt::Debug for StateNotFound { } impl std::error::Error for StateNotFound {} -impl FromConnectParts for State { +impl FromConnectParts for State { type Error = StateNotFound; fn from_connect_parts( - _: &Arc>, + s: &Arc>, _: &Option, ) -> Result> { - get_state::() + s.get_io() + .get_state::() .map(State) .ok_or(StateNotFound(std::marker::PhantomData)) } } -impl FromDisconnectParts for State { +impl FromDisconnectParts for State { type Error = StateNotFound; fn from_disconnect_parts( - _: &Arc>, + s: &Arc>, _: DisconnectReason, ) -> Result> { - get_state::() + s.get_io() + .get_state::() .map(State) .ok_or(StateNotFound(std::marker::PhantomData)) } } -impl FromMessageParts for State { +impl FromMessageParts for State { type Error = StateNotFound; fn from_message_parts( - _: &Arc>, + s: &Arc>, _: &mut serde_json::Value, _: &mut Vec, _: &Option, ) -> Result> { - get_state::() + s.get_io() + .get_state::() .map(State) .ok_or(StateNotFound(std::marker::PhantomData)) } } -impl Deref for State { - type Target = &'static T; - #[inline(always)] - fn deref(&self) -> &Self::Target { - &self.0 - } -} +super::__impl_deref!(State); diff --git a/socketioxide/src/io.rs b/socketioxide/src/io.rs index db6a4b72..5e6c8678 100644 --- a/socketioxide/src/io.rs +++ b/socketioxide/src/io.rs @@ -57,6 +57,8 @@ pub struct SocketIoBuilder { config: SocketIoConfig, engine_config_builder: EngineIoConfigBuilder, adapter: std::marker::PhantomData, + #[cfg(feature = "state")] + state: state::TypeMap![Send + Sync], } impl SocketIoBuilder { @@ -66,6 +68,8 @@ impl SocketIoBuilder { config: SocketIoConfig::default(), engine_config_builder: EngineIoConfigBuilder::new().req_path("/socket.io".to_string()), adapter: std::marker::PhantomData, + #[cfg(feature = "state")] + state: std::default::Default::default(), } } @@ -159,17 +163,20 @@ impl SocketIoBuilder { config: self.config, engine_config_builder: self.engine_config_builder, adapter: std::marker::PhantomData, + #[cfg(feature = "state")] + state: self.state, } } /// Add a custom global state for the [`SocketIo`] instance. /// This state will be accessible from every handler with the [`State`](crate::extract::State) extractor. /// You can set any number of states as long as they have different types. + /// The state must be cloneable, therefore it is recommended to wrap it in an `Arc` if you want shared state. #[inline] #[cfg_attr(docsrs, doc(cfg(feature = "state")))] #[cfg(feature = "state")] - pub fn with_state(self, state: S) -> Self { - crate::state::set_state(state); + pub fn with_state(self, state: S) -> Self { + self.state.set(state); self } @@ -179,7 +186,11 @@ impl SocketIoBuilder { pub fn build_layer(mut self) -> (SocketIoLayer, SocketIo) { self.config.engine_config = self.engine_config_builder.build(); - let (layer, client) = SocketIoLayer::from_config(Arc::new(self.config)); + let (layer, client) = SocketIoLayer::from_config( + Arc::new(self.config), + #[cfg(feature = "state")] + self.state, + ); (layer, SocketIo(client)) } @@ -190,8 +201,12 @@ impl SocketIoBuilder { pub fn build_svc(mut self) -> (SocketIoService, SocketIo) { self.config.engine_config = self.engine_config_builder.build(); - let (svc, client) = - SocketIoService::with_config_inner(NotFoundService, Arc::new(self.config)); + let (svc, client) = SocketIoService::with_config_inner( + NotFoundService, + Arc::new(self.config), + #[cfg(feature = "state")] + self.state, + ); (svc, SocketIo(client)) } @@ -201,7 +216,12 @@ impl SocketIoBuilder { pub fn build_with_inner_svc(mut self, svc: S) -> (SocketIoService, SocketIo) { self.config.engine_config = self.engine_config_builder.build(); - let (svc, client) = SocketIoService::with_config_inner(svc, Arc::new(self.config)); + let (svc, client) = SocketIoService::with_config_inner( + svc, + Arc::new(self.config), + #[cfg(feature = "state")] + self.state, + ); (svc, SocketIo(client)) } } @@ -791,6 +811,11 @@ impl SocketIo { self.get_default_op().get_socket(sid) } + #[cfg(feature = "state")] + pub(crate) fn get_state(&self) -> Option { + self.0.state.try_get::().cloned() + } + /// Returns a new operator on the given namespace #[inline(always)] fn get_op(&self, path: &str) -> Option> { diff --git a/socketioxide/src/layer.rs b/socketioxide/src/layer.rs index 1a656766..d463e73c 100644 --- a/socketioxide/src/layer.rs +++ b/socketioxide/src/layer.rs @@ -41,8 +41,15 @@ impl Clone for SocketIoLayer { } impl SocketIoLayer { - pub(crate) fn from_config(config: Arc) -> (Self, Arc>) { - let client = Arc::new(Client::new(config.clone())); + pub(crate) fn from_config( + config: Arc, + #[cfg(feature = "state")] state: state::TypeMap![Send + Sync], + ) -> (Self, Arc>) { + let client = Arc::new(Client::new( + config.clone(), + #[cfg(feature = "state")] + state, + )); let layer = Self { client: client.clone(), }; diff --git a/socketioxide/src/lib.rs b/socketioxide/src/lib.rs index c02fb2dd..d685c2fa 100644 --- a/socketioxide/src/lib.rs +++ b/socketioxide/src/lib.rs @@ -158,7 +158,7 @@ //! * [`ProtocolVersion`]: extracts the protocol version of the socket //! * [`TransportType`]: extracts the transport type of the socket //! * [`DisconnectReason`](crate::socket::DisconnectReason): extracts the reason of the disconnection -//! * [`State`](extract::State): extracts a reference to a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). +//! * [`State`](extract::State): extracts a [`Clone`] of a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). //! * [`Extension`](extract::Extension): extracts a clone of the corresponding socket extension //! * [`MaybeExtension`](extract::MaybeExtension): extracts a clone of the corresponding socket extension if it exists //! * [`HttpExtension`](extract::HttpExtension): extracts a clone of the http request extension @@ -255,19 +255,18 @@ //! You can enable the `extensions` feature and use the [`extensions`](socket::Socket::extensions) field on any socket to manage //! the state of each socket. It is backed by a [`RwLock>`](std::sync::RwLock) so you can safely access it //! from multiple threads. However, the value must be [`Clone`] and `'static`. -//! When calling get, or using the [`Extension`](extract::Extension) extractor, the value will always be cloned. +//! When calling get, or using the [`Extension`](extract::Extension)/[`MaybeExtension`](extract::MaybeExtension) extractor, +//! the value will always be cloned. //! See the [`extensions`] module doc for more details. //! //! #### Global state //! You can enable the `state` feature and use [`SocketIoBuilder::with_state`](SocketIoBuilder) method to set //! multiple global states for the server. You can then access them from any handler with the [`State`](extract::State) extractor. //! -//! Because the global state is staticaly defined, beware that the state map will exist for the whole lifetime of the program even -//! if you drop everything and close you socket.io server. This is a limitation because of the impossibility to have extractors with lifetimes, -//! therefore state references must be `'static`. +//! The state is stored in the [`SocketIo`] handle and is shared between all the sockets. The only limitation is that all the provided state types must be clonable. +//! Therefore it is recommended to use the [`Arc`](std::sync::Arc) type to share the state between the handlers. //! -//! Another limitation is that because it is common to the whole server. If you build a second server, it will share the same state. -//! Also if the first server is already started you won't be able to add new states because states are frozen at the start of the first server. +//! You can then use the [`State`](extract::State) extractor to access the state in the handlers. //! //! ## Adapters //! This library is designed to work with clustering. It uses the [`Adapter`](adapter::Adapter) trait to abstract the underlying storage. @@ -285,8 +284,6 @@ pub mod adapter; #[cfg_attr(docsrs, doc(cfg(feature = "extensions")))] #[cfg(feature = "extensions")] pub mod extensions; -#[cfg(feature = "state")] -mod state; pub mod ack; pub mod extract; diff --git a/socketioxide/src/service.rs b/socketioxide/src/service.rs index 45c26731..fd6ea517 100644 --- a/socketioxide/src/service.rs +++ b/socketioxide/src/service.rs @@ -120,9 +120,14 @@ impl SocketIoService { pub(crate) fn with_config_inner( inner: S, config: Arc, + #[cfg(feature = "state")] state: state::TypeMap![Send + Sync], ) -> (Self, Arc>) { let engine_config = config.engine_config.clone(); - let client = Arc::new(Client::new(config)); + let client = Arc::new(Client::new( + config, + #[cfg(feature = "state")] + state, + )); let svc = EngineIoService::with_config_inner(inner, client.clone(), engine_config); (Self { engine_svc: svc }, client) } diff --git a/socketioxide/src/state.rs b/socketioxide/src/state.rs deleted file mode 100644 index a76015d1..00000000 --- a/socketioxide/src/state.rs +++ /dev/null @@ -1,41 +0,0 @@ -//! An optional global state for the server which is backed by a [`TypeMap`]. -//! The state is global and not part of the socketio instance because it is impossible -//! to have extractors with lifetimes. Therefore the only way to propagate a `State` reference -//! to a handler parameter is to make is `'static`. -use state::TypeMap; -static mut STATE: TypeMap![Send + Sync] = ::new(); - -// SAFETY: even if the state is mut and therefeore not thread safe, -// each of the following functions are called disctincly in different -// step of the server lifecycle. - -// SAFETY: the `get_state` is called many times from extractors but never mutably and after the `freeze_state` call -pub(crate) fn get_state() -> Option<&'static T> { - unsafe { STATE.try_get::() } -} -// SAFETY: the `freeze_state` is only called once at the launch of the server -pub(crate) fn freeze_state() { - unsafe { STATE.freeze() } -} -// SAFETY: the `set_state` is only called when building the server and can be called multiple times -pub(crate) fn set_state(value: T) { - unsafe { STATE.set(value) }; -} - -mod test { - #[test] - fn state_lifecycle() { - use super::TypeMap; - // Mimic the parent static state - static mut STATE: TypeMap![Send + Sync] = ::new(); - - unsafe { STATE.set(1i32) }; - unsafe { STATE.set("hello") }; - - assert_eq!(unsafe { STATE.len() }, 2); - unsafe { STATE.freeze() } - - assert_eq!(unsafe { STATE.try_get::() }, Some(&1)); - assert_eq!(unsafe { STATE.try_get::<&str>() }, Some(&"hello")); - } -} diff --git a/socketioxide/tests/connect.rs b/socketioxide/tests/connect.rs index 8767dc62..ba77d21b 100644 --- a/socketioxide/tests/connect.rs +++ b/socketioxide/tests/connect.rs @@ -12,7 +12,7 @@ fn create_msg(ns: &str, event: &str, data: impl Into) -> engi Message(packet.into()) } async fn timeout_rcv(srx: &mut tokio::sync::mpsc::Receiver) -> T { - tokio::time::timeout(std::time::Duration::from_millis(500), srx.recv()) + tokio::time::timeout(std::time::Duration::from_millis(10), srx.recv()) .await .unwrap() .unwrap() diff --git a/socketioxide/tests/extractors.rs b/socketioxide/tests/extractors.rs index 831771a3..e9b38116 100644 --- a/socketioxide/tests/extractors.rs +++ b/socketioxide/tests/extractors.rs @@ -14,13 +14,13 @@ mod fixture; mod utils; async fn timeout_rcv(srx: &mut tokio::sync::mpsc::Receiver) -> T { - tokio::time::timeout(Duration::from_millis(500), srx.recv()) + tokio::time::timeout(Duration::from_millis(10), srx.recv()) .await .unwrap() .unwrap() } async fn timeout_rcv_err(srx: &mut tokio::sync::mpsc::Receiver) { - tokio::time::timeout(Duration::from_millis(500), srx.recv()) + tokio::time::timeout(Duration::from_millis(10), srx.recv()) .await .unwrap_err(); }