Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: deadlocking when calling close_ns from inside a disconnect_handler #316

Merged
merged 5 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions socketioxide/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ use bytes::Bytes;
use engineioxide::handler::EngineIoHandler;
use engineioxide::socket::{DisconnectReason as EIoDisconnectReason, Socket as EIoSocket};
use engineioxide::Str;
use futures_util::TryFutureExt;
use futures_util::{FutureExt, TryFutureExt};

use engineioxide::sid::Sid;
use tokio::sync::oneshot;

use crate::adapter::Adapter;
use crate::handler::ConnectHandler;
use crate::socket::DisconnectReason;
use crate::ProtocolVersion;
use crate::{
errors::Error,
Expand Down Expand Up @@ -121,11 +122,19 @@ impl<A: Adapter> Client<A> {
self.ns.write().unwrap().insert(path, ns);
}

/// Deletes a namespace handler
/// Deletes a namespace handler and closes all the connections to it
pub fn delete_ns(&self, path: &str) {
#[cfg(feature = "v4")]
if path == "/" {
panic!("the root namespace \"/\" cannot be deleted for the socket.io v4 protocol. See https://socket.io/docs/v3/namespaces/#main-namespace for more info");
}

#[cfg(feature = "tracing")]
tracing::debug!("deleting namespace {}", path);
self.ns.write().unwrap().remove(path);
if let Some(ns) = self.ns.write().unwrap().remove(path) {
ns.close(DisconnectReason::ServerNSDisconnect)
.now_or_never();
}
}

pub fn get_ns(&self, path: &str) -> Option<Arc<Namespace<A>>> {
Expand All @@ -138,7 +147,11 @@ impl<A: Adapter> Client<A> {
#[cfg(feature = "tracing")]
tracing::debug!("closing all namespaces");
let ns = self.ns.read().unwrap().clone();
futures_util::future::join_all(ns.values().map(|ns| ns.close())).await;
futures_util::future::join_all(
ns.values()
.map(|ns| ns.close(DisconnectReason::ClosingServer)),
)
.await;
#[cfg(feature = "tracing")]
tracing::debug!("all namespaces closed");
}
Expand Down Expand Up @@ -230,12 +243,16 @@ impl<A: Adapter> EngineIoHandler for Client<A> {
fn on_disconnect(&self, socket: Arc<EIoSocket<SocketData>>, reason: EIoDisconnectReason) {
#[cfg(feature = "tracing")]
tracing::debug!("eio socket disconnected");
let _res: Result<Vec<_>, _> = self
let socks: Vec<_> = self
.ns
.read()
.unwrap()
.values()
.filter_map(|ns| ns.get_socket(socket.id).ok())
.collect();

let _res: Result<Vec<_>, _> = socks
.into_iter()
.map(|s| s.close(reason.clone().into()))
.collect();

Expand Down
9 changes: 8 additions & 1 deletion socketioxide/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,14 @@ impl<A: Adapter> SocketIo<A> {
self.0.add_ns(path.into(), callback);
}

/// Deletes the namespace with the given path
/// Deletes the namespace with the given path.
///
/// This will disconnect all sockets connected to this
/// namespace in a deferred way.
///
/// # Panics
/// If the v4 protocol (legacy) is enabled and the namespace to delete is the default namespace "/".
/// For v4, the default namespace cannot be deleted. See [official doc](https://socket.io/docs/v3/namespaces/#main-namespace) for more informations.
#[inline]
pub fn delete_ns<'a>(&self, path: impl Into<&'a str>) {
self.0.delete_ns(path.into());
Expand Down
53 changes: 43 additions & 10 deletions socketioxide/src/ns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
errors::{ConnectFail, Error},
handler::{BoxedConnectHandler, ConnectHandler, MakeErasedHandler},
packet::{Packet, PacketData},
socket::Socket,
socket::{DisconnectReason, Socket},
SocketIoConfig,
};
use crate::{client::SocketData, errors::AdapterError};
Expand Down Expand Up @@ -85,6 +85,9 @@ impl<A: Adapter> Namespace<A> {

/// Removes a socket from a namespace and propagate the event to the adapter
pub fn remove_socket(&self, sid: Sid) -> Result<(), AdapterError> {
#[cfg(feature = "tracing")]
tracing::trace!(?sid, "removing socket from namespace");

self.sockets.write().unwrap().remove(&sid);
self.adapter
.del_all(sid)
Expand Down Expand Up @@ -118,18 +121,40 @@ impl<A: Adapter> Namespace<A> {

/// Closes the entire namespace :
/// * Closes the adapter
/// * Closes all the sockets and their underlying connections
/// * Closes all the sockets and
/// their underlying connections in case of [`DisconnectReason::ClosingServer`]
/// * Removes all the sockets from the namespace
pub async fn close(&self) {
self.adapter.close().ok();
#[cfg(feature = "tracing")]
tracing::debug!("closing all sockets in namespace {}", self.path);
///
/// This function is using .await points only when called with [`DisconnectReason::ClosingServer`]
pub async fn close(&self, reason: DisconnectReason) {
use futures_util::future;
let sockets = self.sockets.read().unwrap().clone();
futures_util::future::join_all(sockets.values().map(|s| s.close_underlying_transport()))
.await;
self.sockets.write().unwrap().shrink_to_fit();

#[cfg(feature = "tracing")]
tracing::debug!(?self.path, "closing {} sockets in namespace", sockets.len());

if reason == DisconnectReason::ClosingServer {
// When closing the underlying transport, this will indirectly close the socket
// Therefore there is no need to manually call `s.close()`.
future::join_all(sockets.values().map(|s| s.close_underlying_transport())).await;
} else {
for s in sockets.into_values() {
let _sid = s.id;
let _err = s.close(reason);
#[cfg(feature = "tracing")]
if let Err(err) = _err {
tracing::debug!(?_sid, ?err, "error closing socket");
}
}
}
#[cfg(feature = "tracing")]
tracing::debug!(?self.path, "all sockets in namespace closed");

let _err = self.adapter.close();
#[cfg(feature = "tracing")]
tracing::debug!("all sockets in namespace {} closed", self.path);
if let Err(err) = _err {
tracing::debug!(?err, "could not close adapter");
}
}
}

Expand Down Expand Up @@ -160,3 +185,11 @@ impl<A: Adapter + std::fmt::Debug> std::fmt::Debug for Namespace<A> {
.finish()
}
}

#[cfg(feature = "tracing")]
impl<A: Adapter> Drop for Namespace<A> {
fn drop(&mut self) {
#[cfg(feature = "tracing")]
tracing::debug!("dropping namespace {}", self.path);
}
}
8 changes: 6 additions & 2 deletions socketioxide/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub enum DisconnectReason {
/// The client has manually disconnected the socket using [`socket.disconnect()`](https://socket.io/fr/docs/v4/client-api/#socketdisconnect)
ClientNSDisconnect,

/// The socket was forcefully disconnected from the namespace with [`Socket::disconnect`]
/// The socket was forcefully disconnected from the namespace with [`Socket::disconnect`] or with [`SocketIo::delete_ns`](crate::io::SocketIo::delete_ns)
ServerNSDisconnect,

/// The server is being closed
Expand Down Expand Up @@ -694,7 +694,11 @@ impl<A: Adapter> Socket<A> {
pub(crate) fn close(self: Arc<Self>, reason: DisconnectReason) -> Result<(), AdapterError> {
self.set_connected(false);

if let Some(handler) = self.disconnect_handler.lock().unwrap().take() {
let handler = { self.disconnect_handler.lock().unwrap().take() };
if let Some(handler) = handler {
#[cfg(feature = "tracing")]
tracing::trace!(?reason, ?self.id, "spawning disconnect handler");

handler.call(self.clone(), reason);
}

Expand Down
108 changes: 107 additions & 1 deletion socketioxide/tests/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,22 @@ mod utils;

use bytes::Bytes;
use engineioxide::Packet::*;
use socketioxide::{extract::SocketRef, handler::ConnectHandler, SendError, SocketError, SocketIo};
use socketioxide::{
extract::SocketRef, handler::ConnectHandler, packet::Packet, SendError, SocketError, SocketIo,
};
use tokio::sync::mpsc;

fn create_msg(ns: &str, event: &str, data: impl Into<serde_json::Value>) -> engineioxide::Packet {
let packet: String = Packet::event(ns, event, data.into()).into();
Message(packet.into())
}
async fn timeout_rcv<T: std::fmt::Debug>(srx: &mut tokio::sync::mpsc::Receiver<T>) -> T {
tokio::time::timeout(std::time::Duration::from_millis(500), srx.recv())
.await
.unwrap()
.unwrap()
}

#[tokio::test]
pub async fn connect_middleware() {
let (_svc, io) = SocketIo::new_svc();
Expand Down Expand Up @@ -97,3 +110,96 @@ pub async fn connect_middleware_error() {
rx.recv().await.unwrap();
assert_err!(rx.try_recv());
}

#[tokio::test]
async fn remove_ns_from_connect_handler() {
let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(2);
let (_svc, io) = SocketIo::new_svc();

let io_clone = io.clone();
io.ns("/test1", move || {
tx.try_send(()).unwrap();
io_clone.delete_ns("/test1");
});

let (stx, mut srx) = io.new_dummy_sock("/test1", ()).await;
timeout_rcv(&mut srx).await;
assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ())));
timeout_rcv(&mut rx).await;
assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ())));
// No response since ns is already deleted
let elapsed = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv()).await;
assert!(elapsed.is_err() || elapsed.unwrap().is_none());
}

#[tokio::test]
async fn remove_ns_from_middleware() {
let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(2);
let (_svc, io) = SocketIo::new_svc();

let io_clone = io.clone();
let middleware = move || {
tx.try_send(()).unwrap();
io_clone.delete_ns("/test1");
Ok::<(), std::convert::Infallible>(())
};
fn handler() {}
io.ns("/test1", handler.with(middleware));

let (stx, mut srx) = io.new_dummy_sock("/test1", ()).await;
timeout_rcv(&mut srx).await;
assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ())));
timeout_rcv(&mut rx).await;
assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ())));
// No response since ns is already deleted
let elapsed = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv()).await;
assert!(elapsed.is_err() || elapsed.unwrap().is_none());
}

#[tokio::test]
async fn remove_ns_from_event_handler() {
let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(2);
let (_svc, io) = SocketIo::new_svc();

let io_clone = io.clone();
io.ns("/test1", move |s: SocketRef| {
s.on("delete_ns", move || {
io_clone.delete_ns("/test1");
tx.try_send(()).unwrap();
});
});

let (stx, mut srx) = io.new_dummy_sock("/test1", ()).await;
timeout_rcv(&mut srx).await;
assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ())));
timeout_rcv(&mut rx).await;
assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ())));
// No response since ns is already deleted
let elapsed = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv()).await;
assert!(elapsed.is_err() || elapsed.unwrap().is_none());
}

#[tokio::test]
async fn remove_ns_from_disconnect_handler() {
let (tx, mut rx) = tokio::sync::mpsc::channel::<&'static str>(2);
let (_svc, io) = SocketIo::new_svc();

let io_clone = io.clone();
io.ns("/test2", move |s: SocketRef| {
tx.try_send("connect").unwrap();
s.on_disconnect(move || {
io_clone.delete_ns("/test2");
tx.try_send("disconnect").unwrap();
})
});

let (stx, mut srx) = io.new_dummy_sock("/test2", ()).await;
assert_eq!(timeout_rcv(&mut rx).await, "connect");
timeout_rcv(&mut srx).await;
assert_ok!(stx.try_send(Close));
assert_eq!(timeout_rcv(&mut rx).await, "disconnect");

let (_stx, mut _srx) = io.new_dummy_sock("/test2", ()).await;
let elapsed = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv()).await;
assert!(elapsed.is_err() || elapsed.unwrap().is_none());
}
21 changes: 21 additions & 0 deletions socketioxide/tests/disconnect_reason.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,27 @@ pub async fn server_ns_disconnect() {
assert_eq!(data, DisconnectReason::ServerNSDisconnect);
}

#[tokio::test]
pub async fn server_ns_close() {
let (tx, mut rx) = mpsc::channel::<DisconnectReason>(1);
let io = create_server(12353).await;
let io2 = io.clone();
io.ns("/test", move |socket: SocketRef| {
socket.on_disconnect(move |reason: DisconnectReason| tx.try_send(reason).unwrap());
io2.delete_ns("/test");
});

let mut ws = create_ws_connection(12353).await;
ws.send(Message::Text("40/test,{}".to_string()))
.await
.unwrap();
let data = tokio::time::timeout(Duration::from_millis(20), rx.recv())
.await
.expect("timeout waiting for DisconnectReason::ServerNSDisconnect")
.unwrap();
assert_eq!(data, DisconnectReason::ServerNSDisconnect);
}

#[tokio::test]
pub async fn server_ws_closing() {
let io = create_server(12350).await;
Expand Down