From b3b23bf83b4722c738bd16eee53fda40a9e1cae7 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Mon, 5 Feb 2024 19:07:59 +0100 Subject: [PATCH 1/4] add TowerService build and notify on session close --- examples/examples/jsonrpsee_as_service.rs | 70 +++++++++++++++++------ server/src/future.rs | 26 ++++++++- server/src/server.rs | 45 +++++++++++++-- server/src/transport/ws.rs | 6 +- 4 files changed, 124 insertions(+), 23 deletions(-) diff --git a/examples/examples/jsonrpsee_as_service.rs b/examples/examples/jsonrpsee_as_service.rs index c738c05259..e6b60e9e9a 100644 --- a/examples/examples/jsonrpsee_as_service.rs +++ b/examples/examples/jsonrpsee_as_service.rs @@ -36,6 +36,7 @@ use std::net::SocketAddr; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use futures::FutureExt; use hyper::header::AUTHORIZATION; use hyper::server::conn::AddrStream; use hyper::HeaderMap; @@ -45,16 +46,18 @@ use jsonrpsee::proc_macros::rpc; use jsonrpsee::server::middleware::rpc::{ResponseFuture, RpcServiceBuilder, RpcServiceT}; use jsonrpsee::server::{stop_channel, ServerHandle, StopHandle, TowerServiceBuilder}; use jsonrpsee::types::{ErrorObject, ErrorObjectOwned, Request}; -use jsonrpsee::ws_client::HeaderValue; +use jsonrpsee::ws_client::{HeaderValue, WsClientBuilder}; use jsonrpsee::{MethodResponse, Methods}; use tower::Service; use tower_http::cors::CorsLayer; use tracing_subscriber::util::SubscriberInitExt; -#[derive(Default, Clone)] +#[derive(Default, Clone, Debug)] struct Metrics { - ws_connections: Arc, - http_connections: Arc, + opened_ws_connections: Arc, + closed_ws_connections: Arc, + http_calls: Arc, + success_http_calls: Arc, } #[derive(Clone)] @@ -106,7 +109,9 @@ async fn main() -> anyhow::Result<()> { let filter = tracing_subscriber::EnvFilter::try_from_default_env()?; tracing_subscriber::FmtSubscriber::builder().with_env_filter(filter).finish().try_init()?; - let handle = run_server(); + let metrics = Metrics::default(); + + let handle = 
run_server(metrics.clone()); tokio::spawn(handle.stopped()); { @@ -117,6 +122,14 @@ async fn main() -> anyhow::Result<()> { tracing::info!("response: {x}"); } + { + let client = WsClientBuilder::default().build("ws://127.0.0.1:9944").await.unwrap(); + + // Fails because the authorization header is missing. + let x = client.trusted_call().await.unwrap_err(); + tracing::info!("response: {x}"); + } + { let mut headers = HeaderMap::new(); headers.insert(AUTHORIZATION, HeaderValue::from_static("don't care in this example")); @@ -127,10 +140,12 @@ async fn main() -> anyhow::Result<()> { tracing::info!("response: {x}"); } + tracing::info!("{:?}", metrics); + Ok(()) } -fn run_server() -> ServerHandle { +fn run_server(metrics: Metrics) -> ServerHandle { use hyper::service::{make_service_fn, service_fn}; let addr = SocketAddr::from(([127, 0, 0, 1], 9944)); @@ -159,7 +174,7 @@ fn run_server() -> ServerHandle { let per_conn = PerConnection { methods: ().into_rpc().into(), stop_handle: stop_handle.clone(), - metrics: Metrics::default(), + metrics, svc_builder: jsonrpsee::server::Server::builder() .set_http_middleware(tower::ServiceBuilder::new().layer(CorsLayer::permissive())) .max_connections(33) @@ -183,17 +198,40 @@ fn run_server() -> ServerHandle { AuthorizationMiddleware { inner: service, headers: headers.clone(), transport_label } }); - let mut svc = svc_builder.set_rpc_middleware(rpc_middleware).build(methods, stop_handle); + if is_websocket { + // Utilize the session close future to know when the actual WebSocket + // session was closed. + let (mut svc, session_close) = svc_builder + .set_rpc_middleware(rpc_middleware) + .build_and_notify_on_session_close(methods, stop_handle); + + // A little bit weird API but the response to HTTP request must be returned below + // and we spawn a task to register when the session is closed. 
+ tokio::spawn(async move { + session_close.await; + metrics.closed_ws_connections.fetch_add(1, Ordering::Relaxed); + }); + + async move { + metrics.opened_ws_connections.fetch_add(1, Ordering::Relaxed); + svc.call(req).await + } + .boxed() + } else { + // HTTP. + let mut svc = svc_builder.set_rpc_middleware(rpc_middleware).build(methods, stop_handle); + + async move { + metrics.http_calls.fetch_add(1, Ordering::Relaxed); + let rp = svc.call(req).await; + + if rp.is_ok() { + metrics.success_http_calls.fetch_add(1, Ordering::Relaxed); + } - async move { - // You can't determine whether the websocket upgrade handshake failed or not here. - let rp = svc.call(req).await; - if is_websocket { - metrics.ws_connections.fetch_add(1, Ordering::Relaxed); - } else { - metrics.http_connections.fetch_add(1, Ordering::Relaxed); + rp } - rp + .boxed() } })) } diff --git a/server/src/future.rs b/server/src/future.rs index d187ef58e2..9d413581b6 100644 --- a/server/src/future.rs +++ b/server/src/future.rs @@ -30,7 +30,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use futures_util::{Stream, StreamExt}; +use futures_util::{Future, Stream, StreamExt}; use pin_project::pin_project; use tokio::sync::{watch, OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::time::Interval; @@ -157,3 +157,27 @@ impl Stream for IntervalStream { } } } + +#[derive(Debug, Clone)] +pub(crate) struct SessionCloseTx(tokio::sync::mpsc::Sender<()>); + +/// A future that resolves when the a connection +/// has been closed. 
+#[derive(Debug)] +pub struct SessionCloseFuture(tokio::sync::mpsc::Receiver<()>); + +impl Future for SessionCloseFuture { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.0.poll_recv(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(_) => Poll::Ready(()), + } + } +} + +pub(crate) fn on_session_close() -> (SessionCloseTx, SessionCloseFuture) { + let (tx, rx) = tokio::sync::mpsc::channel(1); + (SessionCloseTx(tx), SessionCloseFuture(rx)) +} diff --git a/server/src/server.rs b/server/src/server.rs index 2adb48b23a..022a319528 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -33,7 +33,7 @@ use std::sync::Arc; use std::task::Poll; use std::time::Duration; -use crate::future::{ConnectionGuard, ServerHandle, StopHandle}; +use crate::future::{on_session_close, ConnectionGuard, ServerHandle, SessionCloseFuture, SessionCloseTx, StopHandle}; use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}; use crate::transport::ws::BackgroundTaskParams; use crate::transport::{http, ws}; @@ -501,11 +501,42 @@ impl TowerServiceBuilder, + stop_handle: StopHandle, + ) -> (TowerService, SessionCloseFuture) { + let conn_id = self.conn_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let (tx, rx) = on_session_close(); + + let rpc_middleware = TowerServiceNoHttp { + rpc_middleware: self.rpc_middleware, + inner: ServiceData { + methods: methods.into(), + stop_handle, + conn_id, + conn_guard: self.conn_guard, + server_cfg: self.server_cfg, + }, + on_session_close: Some(tx), + }; + + (TowerService { rpc_middleware, http_middleware: self.http_middleware }, rx) + } + /// Configure the connection id. /// /// This is incremented every time `build` is called. 
@@ -617,18 +648,18 @@ impl Builder { /// impl<'a, S> RpcServiceT<'a> for MyMiddleware /// where S: RpcServiceT<'a> + Send + Sync + Clone + 'static, /// { - /// type Future = BoxFuture<'a, MethodResponse>; - /// + /// type Future = BoxFuture<'a, MethodResponse>; + /// /// fn call(&self, req: Request<'a>) -> Self::Future { /// tracing::info!("MyMiddleware processed call {}", req.method); /// let count = self.count.clone(); - /// let service = self.service.clone(); + /// let service = self.service.clone(); /// /// Box::pin(async move { /// let rp = service.call(req).await; /// // Modify the state. /// count.fetch_add(1, Ordering::Relaxed); - /// rp + /// rp /// }) /// } /// } @@ -979,6 +1010,7 @@ where pub struct TowerServiceNoHttp { inner: ServiceData, rpc_middleware: RpcServiceBuilder, + on_session_close: Option, } impl hyper::service::Service> for TowerServiceNoHttp @@ -1004,6 +1036,7 @@ where let conn_guard = &self.inner.conn_guard; let stop_handle = self.inner.stop_handle.clone(); let conn_id = self.inner.conn_id; + let on_session_close = self.on_session_close.take(); tracing::trace!(target: LOG_TARGET, "{:?}", request); @@ -1076,6 +1109,7 @@ where sink, rx, pending_calls_completed, + on_session_close, }; ws::background_task(params).await; @@ -1176,6 +1210,7 @@ fn process_connection<'a, RpcMiddleware, HttpMiddleware, U>( conn_guard: conn_guard.clone(), }, rpc_middleware, + on_session_close: None, }; let service = http_middleware.service(tower_service); diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 2480529847..6986def6a1 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use std::time::Instant; -use crate::future::IntervalStream; +use crate::future::{IntervalStream, SessionCloseTx}; use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}; use crate::server::{handle_rpc_call, ConnectionState, ServerConfig}; use crate::{PingConfig, LOG_TARGET}; @@ 
-56,6 +56,7 @@ pub(crate) struct BackgroundTaskParams { pub(crate) sink: MethodSink, pub(crate) rx: mpsc::Receiver, pub(crate) pending_calls_completed: mpsc::Receiver<()>, + pub(crate) on_session_close: Option, } pub(crate) async fn background_task(params: BackgroundTaskParams) @@ -71,6 +72,7 @@ where sink, rx, pending_calls_completed, + on_session_close, } = params; let ServerConfig { ping_config, batch_requests_config, max_request_body_size, max_response_body_size, .. } = server_cfg; @@ -169,6 +171,7 @@ where graceful_shutdown(result, pending_calls_completed, ws_stream, conn_tx, send_task_handle).await; drop(conn); + drop(on_session_close); } /// A task that waits for new messages via the `rx channel` and sends them out on the `WebSocket`. @@ -445,6 +448,7 @@ where sink, rx, pending_calls_completed, + on_session_close: None, }; background_task(params).await; From e29f43e1e4e9f6fd6d9f0428ae3c614ab51c9ce3 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 6 Feb 2024 10:41:17 +0100 Subject: [PATCH 2/4] refactor the API --- examples/examples/jsonrpsee_as_service.rs | 12 ++--- server/Cargo.toml | 2 +- server/src/future.rs | 36 ++++++++++----- server/src/server.rs | 53 +++++++++-------------- server/src/transport/ws.rs | 11 +++-- 5 files changed, 62 insertions(+), 52 deletions(-) diff --git a/examples/examples/jsonrpsee_as_service.rs b/examples/examples/jsonrpsee_as_service.rs index e6b60e9e9a..a87da4f730 100644 --- a/examples/examples/jsonrpsee_as_service.rs +++ b/examples/examples/jsonrpsee_as_service.rs @@ -198,30 +198,31 @@ fn run_server(metrics: Metrics) -> ServerHandle { AuthorizationMiddleware { inner: service, headers: headers.clone(), transport_label } }); + let mut svc = svc_builder.set_rpc_middleware(rpc_middleware).build(methods, stop_handle); + if is_websocket { // Utilize the session close future to know when the actual WebSocket // session was closed. 
- let (mut svc, session_close) = svc_builder - .set_rpc_middleware(rpc_middleware) - .build_and_notify_on_session_close(methods, stop_handle); + let session_close = svc.on_session_closed(); // A little bit weird API but the response to HTTP request must be returned below // and we spawn a task to register when the session is closed. tokio::spawn(async move { session_close.await; + tracing::info!("Closed WebSocket connection"); metrics.closed_ws_connections.fetch_add(1, Ordering::Relaxed); }); async move { + tracing::info!("Opened WebSocket connection"); metrics.opened_ws_connections.fetch_add(1, Ordering::Relaxed); svc.call(req).await } .boxed() } else { // HTTP. - let mut svc = svc_builder.set_rpc_middleware(rpc_middleware).build(methods, stop_handle); - async move { + tracing::info!("Opened HTTP connection"); metrics.http_calls.fetch_add(1, Ordering::Relaxed); let rp = svc.call(req).await; @@ -229,6 +230,7 @@ fn run_server(metrics: Metrics) -> ServerHandle { metrics.success_http_calls.fetch_add(1, Ordering::Relaxed); } + tracing::info!("Closed HTTP connection"); rp } .boxed() diff --git a/server/Cargo.toml b/server/Cargo.toml index 266497c840..5820c80584 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -23,7 +23,7 @@ serde_json = { version = "1", features = ["raw_value"] } soketto = { version = "0.7.1", features = ["http"] } tokio = { version = "1.16", features = ["net", "rt-multi-thread", "macros", "time"] } tokio-util = { version = "0.7", features = ["compat"] } -tokio-stream = "0.1.7" +tokio-stream = { version = "0.1.7", features = ["sync"] } hyper = { version = "0.14", features = ["server", "http1", "http2"] } tower = "0.4.13" thiserror = "1" diff --git a/server/src/future.rs b/server/src/future.rs index 9d413581b6..bc07d9fb28 100644 --- a/server/src/future.rs +++ b/server/src/future.rs @@ -34,6 +34,7 @@ use futures_util::{Future, Stream, StreamExt}; use pin_project::pin_project; use tokio::sync::{watch, OwnedSemaphorePermit, Semaphore, 
TryAcquireError}; use tokio::time::Interval; +use tokio_stream::wrappers::BroadcastStream; /// Create channel to determine whether /// the server shall continue to run or not. @@ -159,25 +160,40 @@ impl Stream for IntervalStream { } } #[derive(Debug, Clone)] -pub(crate) struct SessionCloseTx(tokio::sync::mpsc::Sender<()>); +pub(crate) struct SessionClose(tokio::sync::broadcast::Sender<()>); -/// A future that resolves when the a connection -/// has been closed. +impl SessionClose { + pub(crate) fn close(self) { + let _ = self.0.send(()); + } + + pub(crate) fn closed(&self) -> SessionClosedFuture { + SessionClosedFuture(BroadcastStream::new(self.0.subscribe())) + } +} + +/// A future that resolves when the connection has been closed. #[derive(Debug)] -pub struct SessionCloseFuture(tokio::sync::mpsc::Receiver<()>); +pub struct SessionClosedFuture(BroadcastStream<()>); -impl Future for SessionCloseFuture { +impl Future for SessionClosedFuture { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.0.poll_recv(cx) { + match self.0.poll_next_unpin(cx) { Poll::Pending => Poll::Pending, - Poll::Ready(_) => Poll::Ready(()), + // A message is only sent when + Poll::Ready(x) => { + tracing::info!("{:?}", x); + Poll::Ready(()) + } } } } -pub(crate) fn on_session_close() -> (SessionCloseTx, SessionCloseFuture) { - let (tx, rx) = tokio::sync::mpsc::channel(1); - (SessionCloseTx(tx), SessionCloseFuture(rx)) +pub(crate) fn session_close() -> (SessionClose, SessionClosedFuture) { + // SessionClosedFuture is closed after one message has been received + // and at most one message is handled, then it's closed. 
+ let (tx, rx) = tokio::sync::broadcast::channel(1); + (SessionClose(tx), SessionClosedFuture(BroadcastStream::new(rx))) } diff --git a/server/src/server.rs b/server/src/server.rs index 022a319528..ecae9b3c27 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -33,7 +33,7 @@ use std::sync::Arc; use std::task::Poll; use std::time::Duration; -use crate::future::{on_session_close, ConnectionGuard, ServerHandle, SessionCloseFuture, SessionCloseTx, StopHandle}; +use crate::future::{session_close, ConnectionGuard, ServerHandle, SessionClose, SessionClosedFuture, StopHandle}; use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}; use crate::transport::ws::BackgroundTaskParams; use crate::transport::{http, ws}; @@ -507,36 +507,6 @@ impl TowerServiceBuilder, - stop_handle: StopHandle, - ) -> (TowerService, SessionCloseFuture) { - let conn_id = self.conn_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let (tx, rx) = on_session_close(); - - let rpc_middleware = TowerServiceNoHttp { - rpc_middleware: self.rpc_middleware, - inner: ServiceData { - methods: methods.into(), - stop_handle, - conn_id, - conn_guard: self.conn_guard, - server_cfg: self.server_cfg, - }, - on_session_close: Some(tx), - }; - - (TowerService { rpc_middleware, http_middleware: self.http_middleware }, rx) - } - /// Configure the connection id. /// /// This is incremented every time `build` is called. @@ -972,6 +942,25 @@ pub struct TowerService { http_middleware: tower::ServiceBuilder, } +impl TowerService { + /// A future that returns when the connection has been closed. + /// + /// It's possible to call this many times but internally it uses + /// a bounded buffer of 4 such that if one creates more than 4 + /// SessionCloseFuture's. Then any of these 4 first futures + /// must be polled or dropped to make any progress. 
+ pub fn on_session_closed(&mut self) -> SessionClosedFuture { + if let Some(n) = self.rpc_middleware.on_session_close.as_mut() { + // If it's called more then once another listener is created. + n.closed() + } else { + let (session_close, fut) = session_close(); + self.rpc_middleware.on_session_close = Some(session_close); + fut + } + } +} + impl hyper::service::Service> for TowerService where @@ -1010,7 +999,7 @@ where pub struct TowerServiceNoHttp { inner: ServiceData, rpc_middleware: RpcServiceBuilder, - on_session_close: Option, + on_session_close: Option, } impl hyper::service::Service> for TowerServiceNoHttp diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 6986def6a1..d3fa7a4ddd 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use std::time::Instant; -use crate::future::{IntervalStream, SessionCloseTx}; +use crate::future::{IntervalStream, SessionClose}; use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}; use crate::server::{handle_rpc_call, ConnectionState, ServerConfig}; use crate::{PingConfig, LOG_TARGET}; @@ -56,7 +56,7 @@ pub(crate) struct BackgroundTaskParams { pub(crate) sink: MethodSink, pub(crate) rx: mpsc::Receiver, pub(crate) pending_calls_completed: mpsc::Receiver<()>, - pub(crate) on_session_close: Option, + pub(crate) on_session_close: Option, } pub(crate) async fn background_task(params: BackgroundTaskParams) @@ -72,7 +72,7 @@ where sink, rx, pending_calls_completed, - on_session_close, + mut on_session_close, } = params; let ServerConfig { ping_config, batch_requests_config, max_request_body_size, max_response_body_size, .. 
} = server_cfg; @@ -171,7 +171,10 @@ where graceful_shutdown(result, pending_calls_completed, ws_stream, conn_tx, send_task_handle).await; drop(conn); - drop(on_session_close); + + if let Some(c) = on_session_close.take() { + c.close(); + } } /// A task that waits for new messages via the `rx channel` and sends them out on the `WebSocket`. From c1ea188a37f016e9a2f51d007a693a447250a91d Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 6 Feb 2024 12:02:27 +0100 Subject: [PATCH 3/4] clarify docs --- server/src/server.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/server/src/server.rs b/server/src/server.rs index ecae9b3c27..2dfff2364c 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -945,10 +945,9 @@ pub struct TowerService { impl TowerService { /// A future that returns when the connection has been closed. /// - /// It's possible to call this many times but internally it uses - /// a bounded buffer of 4 such that if one creates more than 4 - /// SessionCloseFuture's. Then any of these 4 first futures - /// must be polled or dropped to make any progress. + /// This method must be called before every [`TowerService::call`] + /// because the `SessionClosedFuture` may already have been consumed or + /// not used. pub fn on_session_closed(&mut self) -> SessionClosedFuture { if let Some(n) = self.rpc_middleware.on_session_close.as_mut() { // If it's called more then once another listener is created. 
From fc8eaf99dde9d57d97cd28798069974e35b756fa Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 6 Feb 2024 12:46:33 +0100 Subject: [PATCH 4/4] add test for on_session_close --- server/src/future.rs | 8 ++--- server/src/tests/helpers.rs | 71 +++++++++++++++++++++++++++++++++++-- server/src/tests/ws.rs | 27 +++++++++++++- 3 files changed, 98 insertions(+), 8 deletions(-) diff --git a/server/src/future.rs b/server/src/future.rs index bc07d9fb28..d26b9cb936 100644 --- a/server/src/future.rs +++ b/server/src/future.rs @@ -182,11 +182,9 @@ impl Future for SessionClosedFuture { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.0.poll_next_unpin(cx) { Poll::Pending => Poll::Pending, - // A message is only sent when - Poll::Ready(x) => { - tracing::info!("{:?}", x); - Poll::Ready(()) - } + // Only one message is ever sent and + // "can't keep up" errors are ignored. + Poll::Ready(_) => Poll::Ready(()), } } } diff --git a/server/src/tests/helpers.rs b/server/src/tests/helpers.rs index 413aa7bd0d..968f22d81b 100644 --- a/server/src/tests/helpers.rs +++ b/server/src/tests/helpers.rs @@ -1,11 +1,18 @@ -use std::fmt; +use std::error::Error as StdError; use std::net::SocketAddr; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::{fmt, sync::atomic::AtomicUsize}; -use crate::{RpcModule, ServerBuilder, ServerHandle}; +use crate::{stop_channel, RpcModule, Server, ServerBuilder, ServerHandle}; +use futures_util::FutureExt; +use hyper::server::conn::AddrStream; +use jsonrpsee_core::server::Methods; use jsonrpsee_core::{DeserializeOwned, RpcResult, StringError}; use jsonrpsee_test_utils::TimeoutFutureExt; use jsonrpsee_types::{error::ErrorCode, ErrorObject, ErrorObjectOwned, Response, ResponseSuccess}; +use tower::Service; use tracing_subscriber::{EnvFilter, FmtSubscriber}; pub(crate) struct TestContext; @@ -194,3 +201,63 @@ impl From for ErrorObjectOwned { fn invalid_params() -> ErrorObjectOwned { ErrorCode::InvalidParams.into() } + 
+#[derive(Debug, Clone, Default)] +pub(crate) struct Metrics { + pub(crate) ws_sessions_opened: Arc, + pub(crate) ws_sessions_closed: Arc, +} + +pub(crate) fn ws_server_with_stats(metrics: Metrics) -> SocketAddr { + use hyper::service::{make_service_fn, service_fn}; + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (stop_handle, server_handle) = stop_channel(); + let stop_handle2 = stop_handle.clone(); + + // And a MakeService to handle each connection... + let make_service = make_service_fn(move |_conn: &AddrStream| { + let stop_handle = stop_handle2.clone(); + let metrics = metrics.clone(); + + async move { + Ok::<_, Box>(service_fn(move |req| { + let is_websocket = crate::ws::is_upgrade_request(&req); + let metrics = metrics.clone(); + let stop_handle = stop_handle.clone(); + + let mut svc = + Server::builder().max_connections(33).to_service_builder().build(Methods::new(), stop_handle); + + if is_websocket { + // This should work for each callback. + let session_close1 = svc.on_session_closed(); + let session_close2 = svc.on_session_closed(); + + tokio::spawn(async move { + metrics.ws_sessions_opened.fetch_add(1, Ordering::SeqCst); + tokio::join!(session_close2, session_close1); + metrics.ws_sessions_closed.fetch_add(1, Ordering::SeqCst); + }); + + async move { svc.call(req).await }.boxed() + } else { + // HTTP. + async move { svc.call(req).await }.boxed() + } + })) + } + }); + + let server = hyper::Server::bind(&addr).serve(make_service); + + let addr = server.local_addr(); + + tokio::spawn(async move { + let graceful = server.with_graceful_shutdown(async move { stop_handle.shutdown().await }); + graceful.await.unwrap(); + drop(server_handle) + }); + + addr +} diff --git a/server/src/tests/ws.rs b/server/src/tests/ws.rs index 4829f41a3a..f7c1381d63 100644 --- a/server/src/tests/ws.rs +++ b/server/src/tests/ws.rs @@ -24,9 +24,10 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
+use std::sync::atomic::Ordering; use std::time::Duration; -use crate::tests::helpers::{deser_call, init_logger, server_with_context}; +use crate::tests::helpers::{deser_call, init_logger, server_with_context, ws_server_with_stats, Metrics}; use crate::types::SubscriptionId; use crate::{BatchRequestConfig, RegisterMethodError}; use crate::{RpcModule, ServerBuilder}; @@ -874,6 +875,30 @@ async fn drop_client_with_pending_calls_works() { assert!(handle.stopped().with_timeout(MAX_TIMEOUT).await.is_ok()); } +#[tokio::test] +async fn server_notify_on_conn_close() { + init_logger(); + + let metrics = Metrics::default(); + let addr = ws_server_with_stats(metrics.clone()); + + let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap(); + + // Wait for the server to process + tokio::time::sleep(Duration::from_millis(100)).await; + + assert_eq!(metrics.ws_sessions_opened.load(Ordering::SeqCst), 1); + assert_eq!(metrics.ws_sessions_closed.load(Ordering::SeqCst), 0); + + client.close().with_default_timeout().await.unwrap().unwrap(); + + // Wait for the server to process + tokio::time::sleep(Duration::from_millis(100)).await; + + assert_eq!(metrics.ws_sessions_opened.load(Ordering::SeqCst), 1); + assert_eq!(metrics.ws_sessions_closed.load(Ordering::SeqCst), 1); +} + async fn server_with_infinite_call( timeout: Duration, tx: tokio::sync::mpsc::UnboundedSender<()>,