Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(socketio/extensions): rework extensions and add an Extension decorator and an HttpExtension decorator #309

Merged
merged 25 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
9553267
feat(socketio/extensions): use `RwLock<HashMap>` rather than `DashMap`
Totodore Apr 19, 2024
46c7a1b
chore(bench): add bencher ci
Totodore Apr 19, 2024
d7d6a3f
fix: socketioxide benches with `Bytes`
Totodore Apr 19, 2024
c307a73
chore(bench): fix ci name
Totodore Apr 19, 2024
2894317
chore(bench): add RUSTFLAG for testing
Totodore Apr 19, 2024
44d73ba
fix: engineioxide benches
Totodore Apr 19, 2024
191e3fa
chore(bench): remove matrix test
Totodore Apr 19, 2024
66aeef9
chore(bench): add groups
Totodore Apr 19, 2024
82fefb5
chore(bench): improve extensions bench
Totodore Apr 19, 2024
3652d0a
Merge branch 'bencher' into feat-extensions-rework
Totodore Apr 20, 2024
df530f1
Merge branch 'main' into feat-extensions-rework
Totodore Apr 20, 2024
164a7ae
feat(socketio/extract): refactor extract mod
Totodore Apr 20, 2024
f6008a5
feat(socketio/extract): add `(Maybe)(Http)Extension` extractors
Totodore Apr 20, 2024
ecff81a
docs(example): update examples with `Extension` extractor
Totodore Apr 20, 2024
5070dd0
test(socketio/extract): add tests for `Extension` and `MaybeExtension`
Totodore Apr 20, 2024
a744f7b
docs(example) fmt chat example
Totodore Apr 20, 2024
bf8daab
Merge branch 'main' into feat-extensions-rework
Totodore Apr 20, 2024
f7106db
Merge branch 'main' into feat-extensions-rework
Totodore Apr 21, 2024
50372f2
test(socketio): fix extractors test
Totodore Apr 21, 2024
76c72ca
doc(socketio): improve doc for socketioxide
Totodore Apr 21, 2024
df94f5c
test(socketio): increase timeout
Totodore Apr 21, 2024
5888b37
Merge branch 'main' into feat-extensions-rework
Totodore May 6, 2024
1805c10
Merge branch 'main' into feat-extensions-rework
Totodore May 10, 2024
334e32a
doc(socketio): improve doc
Totodore May 21, 2024
f2c193a
Merge branch 'main' into feat-extensions-rework
Totodore Jun 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions examples/chat/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::atomic::AtomicUsize;

use serde::{Deserialize, Serialize};
use socketioxide::{
extract::{Data, SocketRef, State},
extract::{Data, Extension, SocketRef, State},
SocketIo,
};
use tower::ServiceBuilder;
Expand Down Expand Up @@ -59,14 +59,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let (layer, io) = SocketIo::builder().with_state(UserCnt::new()).build_layer();

io.ns("/", |s: SocketRef| {
s.on("new message", |s: SocketRef, Data::<String>(msg)| {
let username = s.extensions.get::<Username>().unwrap().clone();
let msg = Res::Message {
username,
message: msg,
};
s.broadcast().emit("new message", msg).ok();
});
s.on(
"new message",
|s: SocketRef, Data::<String>(msg), Extension::<Username>(username)| {
let msg = Res::Message {
username,
message: msg,
};
s.broadcast().emit("new message", msg).ok();
},
);

s.on(
"add user",
Expand All @@ -86,30 +88,31 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
},
);

s.on("typing", |s: SocketRef| {
let username = s.extensions.get::<Username>().unwrap().clone();
s.on("typing", |s: SocketRef, Extension::<Username>(username)| {
s.broadcast()
.emit("typing", Res::Username { username })
.ok();
});

s.on("stop typing", |s: SocketRef| {
let username = s.extensions.get::<Username>().unwrap().clone();
s.broadcast()
.emit("stop typing", Res::Username { username })
.ok();
});
s.on(
"stop typing",
|s: SocketRef, Extension::<Username>(username)| {
s.broadcast()
.emit("stop typing", Res::Username { username })
.ok();
},
);

s.on_disconnect(|s: SocketRef, user_cnt: State<UserCnt>| {
if let Some(username) = s.extensions.get::<Username>() {
s.on_disconnect(
|s: SocketRef, user_cnt: State<UserCnt>, Extension::<Username>(username)| {
let num_users = user_cnt.remove_user();
let res = Res::UserEvent {
num_users,
username: username.clone(),
username,
};
s.broadcast().emit("user left", res).ok();
}
});
},
);
});

let app = axum::Router::new()
Expand Down
1 change: 1 addition & 0 deletions examples/private-messaging/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ edition = "2021"
socketioxide = { path = "../../socketioxide", features = [
"extensions",
"state",
"tracing",
] }
axum.workspace = true
tokio = { workspace = true, features = ["rt-multi-thread", "macros"] }
Expand Down
119 changes: 61 additions & 58 deletions examples/private-messaging/src/handlers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::sync::{atomic::Ordering, Arc};

use anyhow::anyhow;
use serde::{Deserialize, Serialize};
use socketioxide::extract::{Data, SocketRef, State, TryData};
use socketioxide::extract::{Data, Extension, SocketRef, State};
use uuid::Uuid;

use crate::store::{Message, Messages, Session, Sessions};
Expand All @@ -22,12 +24,19 @@ struct UserConnectedRes {
messages: Vec<Message>,
}

#[derive(Debug, Serialize, Clone)]
struct UserDisconnectedRes {
#[serde(rename = "userID")]
user_id: Uuid,
username: String,
}

impl UserConnectedRes {
fn new(session: &Session, messages: Vec<Message>) -> Self {
Self {
user_id: session.user_id,
username: session.username.clone(),
connected: session.connected,
connected: session.connected.load(Ordering::SeqCst),
messages,
}
}
Expand All @@ -38,86 +47,80 @@ struct PrivateMessageReq {
content: String,
}

pub fn on_connection(s: SocketRef) {
pub fn on_connection(
s: SocketRef,
Extension::<Arc<Session>>(session): Extension<Arc<Session>>,
State(sessions): State<Sessions>,
State(msgs): State<Messages>,
) {
s.emit("session", (*session).clone()).unwrap();

let users = sessions
.get_all_other_sessions(session.user_id)
.into_iter()
.map(|session| {
let messages = msgs.get_all_for_user(session.user_id);
UserConnectedRes::new(&session, messages)
})
.collect::<Vec<_>>();

s.emit("users", [users]).unwrap();

let res = UserConnectedRes::new(&session, vec![]);
s.broadcast().emit("user connected", res).unwrap();

s.on(
"private message",
|s: SocketRef, Data(PrivateMessageReq { to, content }), State(Messages(msg))| {
let user_id = s.extensions.get::<Session>().unwrap().user_id;
|s: SocketRef,
Data(PrivateMessageReq { to, content }),
State::<Messages>(msgs),
Extension::<Arc<Session>>(session)| {
let message = Message {
from: user_id,
from: session.user_id,
to,
content,
};
msg.write().unwrap().push(message.clone());
msgs.add(message.clone());
s.within(to.to_string())
.emit("private message", message)
.ok();
},
);

s.on_disconnect(|s: SocketRef, State(Sessions(sessions))| {
let mut session = s.extensions.get::<Session>().unwrap().clone();
session.connected = false;

sessions
.write()
.unwrap()
.get_mut(&session.session_id)
.unwrap()
.connected = false;

s.broadcast().emit("user disconnected", session).ok();
s.on_disconnect(|s: SocketRef, Extension::<Arc<Session>>(session)| {
session.set_connected(false);
s.broadcast()
.emit(
"user disconnected",
UserDisconnectedRes {
user_id: session.user_id,
username: session.username.clone(),
},
)
.ok();
});
}

/// Handles the connection of a new user
/// Handles the connection of a new user.
/// Be careful to not emit anything to the user before the authentication is done.
pub fn authenticate_middleware(
s: SocketRef,
TryData(auth): TryData<Auth>,
State(Sessions(session_state)): State<Sessions>,
State(Messages(msg_state)): State<Messages>,
Data(auth): Data<Auth>,
State(sessions): State<Sessions>,
) -> Result<(), anyhow::Error> {
let auth = auth?;
let mut sessions = session_state.write().unwrap();
if let Some(session) = auth.session_id.and_then(|id| sessions.get_mut(&id)) {
session.connected = true;
let session = if let Some(session) = auth.session_id.and_then(|id| sessions.get(id)) {
session.set_connected(true);
s.extensions.insert(session.clone());
session
} else {
let username = auth.username.ok_or(anyhow!("invalid username"))?;
let session = Session::new(username);
let session = Arc::new(Session::new(username));
s.extensions.insert(session.clone());

sessions.insert(session.session_id, session);
sessions.add(session.clone());
session
};
drop(sessions);

let session = s.extensions.get::<Session>().unwrap();

s.join(session.user_id.to_string()).ok();
s.emit("session", session.clone())?;

let users = session_state
.read()
.unwrap()
.iter()
.filter(|(id, _)| id != &&session.session_id)
.map(|(_, session)| {
let messages = msg_state
.read()
.unwrap()
.iter()
.filter(|message| message.to == session.user_id || message.from == session.user_id)
.cloned()
.collect();

UserConnectedRes::new(session, messages)
})
.collect::<Vec<_>>();

s.emit("users", [users])?;

let res = UserConnectedRes::new(&session, vec![]);
s.join(session.user_id.to_string())?;

s.broadcast().emit("user connected", res)?;
Ok(())
}
67 changes: 61 additions & 6 deletions examples/private-messaging/src/store.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,42 @@
use std::{collections::HashMap, sync::RwLock};
use std::{
collections::HashMap,
sync::{atomic::Ordering, Arc, RwLock},
};

use serde::Serialize;
use std::sync::atomic::AtomicBool;
use uuid::Uuid;

/// Store Types
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Serialize)]
pub struct Session {
#[serde(rename = "sessionID")]
pub session_id: Uuid,
#[serde(rename = "userID")]
pub user_id: Uuid,
pub username: String,
pub connected: bool,
pub connected: AtomicBool,
}
impl Session {
pub fn new(username: String) -> Self {
Self {
session_id: Uuid::new_v4(),
user_id: Uuid::new_v4(),
username,
connected: true,
connected: AtomicBool::new(true),
}
}
pub fn set_connected(&self, connected: bool) {
self.connected.store(connected, Ordering::SeqCst);
}
}
impl Clone for Session {
fn clone(&self) -> Self {
Self {
session_id: self.session_id.clone(),
user_id: self.user_id.clone(),
username: self.username.clone(),
connected: AtomicBool::new(self.connected.load(Ordering::SeqCst)),
}
}
}
Expand All @@ -29,6 +48,42 @@ pub struct Message {
}

#[derive(Default)]
pub struct Sessions(pub RwLock<HashMap<Uuid, Session>>);
pub struct Sessions(RwLock<HashMap<Uuid, Arc<Session>>>);

impl Sessions {
pub fn get_all_other_sessions(&self, user_id: Uuid) -> Vec<Arc<Session>> {
self.0
.read()
.unwrap()
.values()
.filter(|s| s.user_id != user_id)
.cloned()
.collect()
}

pub fn get(&self, user_id: Uuid) -> Option<Arc<Session>> {
self.0.read().unwrap().get(&user_id).cloned()
}

pub fn add(&self, session: Arc<Session>) {
self.0.write().unwrap().insert(session.session_id, session);
}
}
#[derive(Default)]
pub struct Messages(pub RwLock<Vec<Message>>);
pub struct Messages(RwLock<Vec<Message>>);

impl Messages {
pub fn add(&self, message: Message) {
self.0.write().unwrap().push(message);
}

pub fn get_all_for_user(&self, user_id: Uuid) -> Vec<Message> {
self.0
.read()
.unwrap()
.iter()
.filter(|m| m.from == user_id || m.to == user_id)
.cloned()
.collect()
}
}
Loading
Loading