Skip to content

Commit

Permalink
feat(socketio/extensions): rework extensions and add an Extension d…
Browse files Browse the repository at this point in the history
…ecorator and an `HttpExtension` decorator (#309)

* feat(socketio/extensions): use `RwLock<HashMap>` rather than `DashMap`

* chore(bench): add bencher ci

* fix: socketioxide benches with `Bytes`

* chore(bench): fix ci name

* chore(bench): add RUSTFLAG for testing

* fix: engineioxide benches

* chore(bench): remove matrix test

* chore(bench): add groups

* chore(bench): improve extensions bench

* feat(socketio/extract): refactor extract mod

* feat(socketio/extract): add `(Maybe)(Http)Extension` extractors

* docs(example): update examples with `Extension` extractor

* test(socketio/extract): add tests for `Extension` and `MaybeExtension`

* docs(example) fmt chat example

* test(socketio): fix extractors test

* doc(socketio): improve doc for socketioxide

* test(socketio): increase timeout

* doc(socketio): improve doc
  • Loading branch information
Totodore authored Jun 4, 2024
1 parent 041109f commit f509ba7
Show file tree
Hide file tree
Showing 17 changed files with 1,094 additions and 691 deletions.
47 changes: 25 additions & 22 deletions examples/chat/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::atomic::AtomicUsize;

use serde::{Deserialize, Serialize};
use socketioxide::{
extract::{Data, SocketRef, State},
extract::{Data, Extension, SocketRef, State},
SocketIo,
};
use tower::ServiceBuilder;
Expand Down Expand Up @@ -59,14 +59,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let (layer, io) = SocketIo::builder().with_state(UserCnt::new()).build_layer();

io.ns("/", |s: SocketRef| {
s.on("new message", |s: SocketRef, Data::<String>(msg)| {
let username = s.extensions.get::<Username>().unwrap().clone();
let msg = Res::Message {
username,
message: msg,
};
s.broadcast().emit("new message", msg).ok();
});
s.on(
"new message",
|s: SocketRef, Data::<String>(msg), Extension::<Username>(username)| {
let msg = Res::Message {
username,
message: msg,
};
s.broadcast().emit("new message", msg).ok();
},
);

s.on(
"add user",
Expand All @@ -86,30 +88,31 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
},
);

s.on("typing", |s: SocketRef| {
let username = s.extensions.get::<Username>().unwrap().clone();
s.on("typing", |s: SocketRef, Extension::<Username>(username)| {
s.broadcast()
.emit("typing", Res::Username { username })
.ok();
});

s.on("stop typing", |s: SocketRef| {
let username = s.extensions.get::<Username>().unwrap().clone();
s.broadcast()
.emit("stop typing", Res::Username { username })
.ok();
});
s.on(
"stop typing",
|s: SocketRef, Extension::<Username>(username)| {
s.broadcast()
.emit("stop typing", Res::Username { username })
.ok();
},
);

s.on_disconnect(|s: SocketRef, user_cnt: State<UserCnt>| {
if let Some(username) = s.extensions.get::<Username>() {
s.on_disconnect(
|s: SocketRef, user_cnt: State<UserCnt>, Extension::<Username>(username)| {
let num_users = user_cnt.remove_user();
let res = Res::UserEvent {
num_users,
username: username.clone(),
username,
};
s.broadcast().emit("user left", res).ok();
}
});
},
);
});

let app = axum::Router::new()
Expand Down
1 change: 1 addition & 0 deletions examples/private-messaging/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ edition = "2021"
socketioxide = { path = "../../socketioxide", features = [
"extensions",
"state",
"tracing",
] }
axum.workspace = true
tokio = { workspace = true, features = ["rt-multi-thread", "macros"] }
Expand Down
119 changes: 61 additions & 58 deletions examples/private-messaging/src/handlers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::sync::{atomic::Ordering, Arc};

use anyhow::anyhow;
use serde::{Deserialize, Serialize};
use socketioxide::extract::{Data, SocketRef, State, TryData};
use socketioxide::extract::{Data, Extension, SocketRef, State};
use uuid::Uuid;

use crate::store::{Message, Messages, Session, Sessions};
Expand All @@ -22,12 +24,19 @@ struct UserConnectedRes {
messages: Vec<Message>,
}

#[derive(Debug, Serialize, Clone)]
struct UserDisconnectedRes {
#[serde(rename = "userID")]
user_id: Uuid,
username: String,
}

impl UserConnectedRes {
fn new(session: &Session, messages: Vec<Message>) -> Self {
Self {
user_id: session.user_id,
username: session.username.clone(),
connected: session.connected,
connected: session.connected.load(Ordering::SeqCst),
messages,
}
}
Expand All @@ -38,86 +47,80 @@ struct PrivateMessageReq {
content: String,
}

pub fn on_connection(s: SocketRef) {
pub fn on_connection(
s: SocketRef,
Extension::<Arc<Session>>(session): Extension<Arc<Session>>,
State(sessions): State<Sessions>,
State(msgs): State<Messages>,
) {
s.emit("session", (*session).clone()).unwrap();

let users = sessions
.get_all_other_sessions(session.user_id)
.into_iter()
.map(|session| {
let messages = msgs.get_all_for_user(session.user_id);
UserConnectedRes::new(&session, messages)
})
.collect::<Vec<_>>();

s.emit("users", [users]).unwrap();

let res = UserConnectedRes::new(&session, vec![]);
s.broadcast().emit("user connected", res).unwrap();

s.on(
"private message",
|s: SocketRef, Data(PrivateMessageReq { to, content }), State(Messages(msg))| {
let user_id = s.extensions.get::<Session>().unwrap().user_id;
|s: SocketRef,
Data(PrivateMessageReq { to, content }),
State::<Messages>(msgs),
Extension::<Arc<Session>>(session)| {
let message = Message {
from: user_id,
from: session.user_id,
to,
content,
};
msg.write().unwrap().push(message.clone());
msgs.add(message.clone());
s.within(to.to_string())
.emit("private message", message)
.ok();
},
);

s.on_disconnect(|s: SocketRef, State(Sessions(sessions))| {
let mut session = s.extensions.get::<Session>().unwrap().clone();
session.connected = false;

sessions
.write()
.unwrap()
.get_mut(&session.session_id)
.unwrap()
.connected = false;

s.broadcast().emit("user disconnected", session).ok();
s.on_disconnect(|s: SocketRef, Extension::<Arc<Session>>(session)| {
session.set_connected(false);
s.broadcast()
.emit(
"user disconnected",
UserDisconnectedRes {
user_id: session.user_id,
username: session.username.clone(),
},
)
.ok();
});
}

/// Handles the connection of a new user
/// Handles the connection of a new user.
/// Be careful to not emit anything to the user before the authentication is done.
pub fn authenticate_middleware(
s: SocketRef,
TryData(auth): TryData<Auth>,
State(Sessions(session_state)): State<Sessions>,
State(Messages(msg_state)): State<Messages>,
Data(auth): Data<Auth>,
State(sessions): State<Sessions>,
) -> Result<(), anyhow::Error> {
let auth = auth?;
let mut sessions = session_state.write().unwrap();
if let Some(session) = auth.session_id.and_then(|id| sessions.get_mut(&id)) {
session.connected = true;
let session = if let Some(session) = auth.session_id.and_then(|id| sessions.get(id)) {
session.set_connected(true);
s.extensions.insert(session.clone());
session
} else {
let username = auth.username.ok_or(anyhow!("invalid username"))?;
let session = Session::new(username);
let session = Arc::new(Session::new(username));
s.extensions.insert(session.clone());

sessions.insert(session.session_id, session);
sessions.add(session.clone());
session
};
drop(sessions);

let session = s.extensions.get::<Session>().unwrap();

s.join(session.user_id.to_string()).ok();
s.emit("session", session.clone())?;

let users = session_state
.read()
.unwrap()
.iter()
.filter(|(id, _)| id != &&session.session_id)
.map(|(_, session)| {
let messages = msg_state
.read()
.unwrap()
.iter()
.filter(|message| message.to == session.user_id || message.from == session.user_id)
.cloned()
.collect();

UserConnectedRes::new(session, messages)
})
.collect::<Vec<_>>();

s.emit("users", [users])?;

let res = UserConnectedRes::new(&session, vec![]);
s.join(session.user_id.to_string())?;

s.broadcast().emit("user connected", res)?;
Ok(())
}
67 changes: 61 additions & 6 deletions examples/private-messaging/src/store.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,42 @@
use std::{collections::HashMap, sync::RwLock};
use std::{
collections::HashMap,
sync::{atomic::Ordering, Arc, RwLock},
};

use serde::Serialize;
use std::sync::atomic::AtomicBool;
use uuid::Uuid;

/// Store Types
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Serialize)]
pub struct Session {
#[serde(rename = "sessionID")]
pub session_id: Uuid,
#[serde(rename = "userID")]
pub user_id: Uuid,
pub username: String,
pub connected: bool,
pub connected: AtomicBool,
}
impl Session {
pub fn new(username: String) -> Self {
Self {
session_id: Uuid::new_v4(),
user_id: Uuid::new_v4(),
username,
connected: true,
connected: AtomicBool::new(true),
}
}
pub fn set_connected(&self, connected: bool) {
self.connected.store(connected, Ordering::SeqCst);
}
}
impl Clone for Session {
fn clone(&self) -> Self {
Self {
session_id: self.session_id.clone(),
user_id: self.user_id.clone(),
username: self.username.clone(),
connected: AtomicBool::new(self.connected.load(Ordering::SeqCst)),
}
}
}
Expand All @@ -29,6 +48,42 @@ pub struct Message {
}

#[derive(Default)]
pub struct Sessions(pub RwLock<HashMap<Uuid, Session>>);
pub struct Sessions(RwLock<HashMap<Uuid, Arc<Session>>>);

impl Sessions {
pub fn get_all_other_sessions(&self, user_id: Uuid) -> Vec<Arc<Session>> {
self.0
.read()
.unwrap()
.values()
.filter(|s| s.user_id != user_id)
.cloned()
.collect()
}

pub fn get(&self, user_id: Uuid) -> Option<Arc<Session>> {
self.0.read().unwrap().get(&user_id).cloned()
}

pub fn add(&self, session: Arc<Session>) {
self.0.write().unwrap().insert(session.session_id, session);
}
}
#[derive(Default)]
pub struct Messages(pub RwLock<Vec<Message>>);
pub struct Messages(RwLock<Vec<Message>>);

impl Messages {
pub fn add(&self, message: Message) {
self.0.write().unwrap().push(message);
}

pub fn get_all_for_user(&self, user_id: Uuid) -> Vec<Message> {
self.0
.read()
.unwrap()
.iter()
.filter(|m| m.from == user_id || m.to == user_id)
.cloned()
.collect()
}
}
Loading

0 comments on commit f509ba7

Please sign in to comment.