Skip to content

Commit

Permalink
feat(socketio/connect): Add middlewares to namespace (#280)
Browse files Browse the repository at this point in the history
* feat(socketio/handler): add a `with` middleware fn to any connect handler.

* feat(socketio/packet): custom `connect_error` packet

* feat(socketio/handler): return error from middleware.

* fix(clippy): async fn call rather than impl `Future`

* fix(fmt)

* chore: bump MSRV to 1.75.0

* test(socketio/connect): add middleware tests

* feat(socketio/connect): connect to ns only after connect handler result.

* fix(fmt)

* fix(clippy)

* test(socketio/ack): fix ack test for new `connect` handler behaviour

* test(io): unused `Result` lint

* feat(socketio/conect): correct behaviour with connect after middleware

* Revert "test(socketio/ack): fix ack test for new `connect` handler behaviour"

This reverts commit 53ab208.

* test(socketio/connect):  fix ws message assertion

* test: minor improvements

* doc(socketio/connect): improve doc and code readability

* feat(socketio/connect): emit extractor errors in middlewares

* chore(bench/heaptrack): add middleware to bench

* doc(socketio/connect): wip doc

* doc(socketio/connect): wip doc

* doc(socketio/connect): wip doc

* test(socket): set connected for dummy socket

* test(socket): add test for connect status

* feat(socketio/connect): block emission if socket is not connected

* feat(socketio/socket): disconnect status before calling `disconnect` handler.

* doc(socketio/connect): document middlewares

* chore(bench): remove middleware from bench

* doc(socketio/connect): specify middleware behavior for `Data` extractor

* doc(socketio/connect): specify middleware behavior for `Data` extractor
  • Loading branch information
Totodore authored Mar 19, 2024
1 parent 39d700a commit 4ffd473
Show file tree
Hide file tree
Showing 15 changed files with 581 additions and 78 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[workspace.package]
version = "0.11.0"
edition = "2021"
rust-version = "1.67.0"
rust-version = "1.75.0"
authors = ["Théodore Prévot <"]
repository = "https://github.com/totodore/socketioxide"
homepage = "https://github.com/totodore/socketioxide"
Expand Down
24 changes: 6 additions & 18 deletions examples/private-messaging/src/handlers.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use anyhow::anyhow;
use serde::{Deserialize, Serialize};
use socketioxide::extract::{Data, SocketRef, State, TryData};
use tracing::error;
use uuid::Uuid;

use crate::store::{Message, Messages, Session, Sessions};
Expand Down Expand Up @@ -39,18 +38,7 @@ struct PrivateMessageReq {
content: String,
}

pub fn on_connection(
s: SocketRef,
TryData(auth): TryData<Auth>,
sessions: State<Sessions>,
msgs: State<Messages>,
) {
if let Err(e) = session_connect(&s, auth, sessions.0, msgs.0) {
error!("Failed to connect: {:?}", e);
s.disconnect().ok();
return;
}

pub fn on_connection(s: SocketRef) {
s.on(
"private message",
|s: SocketRef, Data(PrivateMessageReq { to, content }), State(Messages(msg))| {
Expand Down Expand Up @@ -83,11 +71,11 @@ pub fn on_connection(
}

/// Handles the connection of a new user
fn session_connect(
s: &SocketRef,
auth: Result<Auth, serde_json::Error>,
Sessions(session_state): &Sessions,
Messages(msg_state): &Messages,
pub fn authenticate_middleware(
s: SocketRef,
TryData(auth): TryData<Auth>,
State(Sessions(session_state)): State<Sessions>,
State(Messages(msg_state)): State<Messages>,
) -> Result<(), anyhow::Error> {
let auth = auth?;
let mut sessions = session_state.write().unwrap();
Expand Down
7 changes: 5 additions & 2 deletions examples/private-messaging/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use socketioxide::SocketIo;
use socketioxide::{handler::ConnectHandler, SocketIo};
use tower::ServiceBuilder;
use tower_http::{cors::CorsLayer, services::ServeDir};
use tracing::info;
Expand All @@ -22,7 +22,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.with_state(Messages::default())
.build_layer();

io.ns("/", handlers::on_connection);
io.ns(
"/",
handlers::on_connection.with(handlers::authenticate_middleware),
);

let app = axum::Router::new()
.nest_service("/", ServeDir::new("dist"))
Expand Down
24 changes: 15 additions & 9 deletions socketioxide/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,21 @@ impl<A: Adapter> Client<A> {
#[cfg(feature = "tracing")]
tracing::debug!("auth: {:?}", auth);

let sid = esocket.id;
if let Some(ns) = self.get_ns(ns_path) {
ns.connect(sid, esocket.clone(), auth, self.config.clone())?;

// cancel the connect timeout task for v5
if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() {
tx.send(()).ok();
}

let esocket = esocket.clone();
let config = self.config.clone();
tokio::spawn(async move {
if ns
.connect(esocket.id, esocket.clone(), auth, config)
.await
.is_ok()
{
// cancel the connect timeout task for v5
if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() {
tx.send(()).ok();
}
}
});
Ok(())
} else if ProtocolVersion::from(esocket.protocol) == ProtocolVersion::V4 && ns_path == "/" {
#[cfg(feature = "tracing")]
Expand All @@ -64,7 +70,7 @@ impl<A: Adapter> Client<A> {
esocket.close(EIoDisconnectReason::TransportClose);
Ok(())
} else {
let packet = Packet::invalid_namespace(ns_path).into();
let packet = Packet::connect_error(ns_path, "Invalid namespace").into();
if let Err(_e) = esocket.emit(packet) {
#[cfg(feature = "tracing")]
tracing::error!("error while sending invalid namespace packet: {}", _e);
Expand Down
2 changes: 2 additions & 0 deletions socketioxide/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub enum Error {
Adapter(#[from] AdapterError),
}

pub(crate) struct ConnectFail;

/// Error type for ack operations.
#[derive(thiserror::Error, Debug)]
pub enum AckError<T> {
Expand Down
Loading

0 comments on commit 4ffd473

Please sign in to comment.