Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(server): add TowerService::on_session_close #1284

Merged
merged 4 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 55 additions & 15 deletions examples/examples/jsonrpsee_as_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use futures::FutureExt;
use hyper::header::AUTHORIZATION;
use hyper::server::conn::AddrStream;
use hyper::HeaderMap;
Expand All @@ -45,16 +46,18 @@ use jsonrpsee::proc_macros::rpc;
use jsonrpsee::server::middleware::rpc::{ResponseFuture, RpcServiceBuilder, RpcServiceT};
use jsonrpsee::server::{stop_channel, ServerHandle, StopHandle, TowerServiceBuilder};
use jsonrpsee::types::{ErrorObject, ErrorObjectOwned, Request};
use jsonrpsee::ws_client::HeaderValue;
use jsonrpsee::ws_client::{HeaderValue, WsClientBuilder};
use jsonrpsee::{MethodResponse, Methods};
use tower::Service;
use tower_http::cors::CorsLayer;
use tracing_subscriber::util::SubscriberInitExt;

#[derive(Default, Clone)]
#[derive(Default, Clone, Debug)]
struct Metrics {
ws_connections: Arc<AtomicUsize>,
http_connections: Arc<AtomicUsize>,
opened_ws_connections: Arc<AtomicUsize>,
closed_ws_connections: Arc<AtomicUsize>,
http_calls: Arc<AtomicUsize>,
success_http_calls: Arc<AtomicUsize>,
}

#[derive(Clone)]
Expand Down Expand Up @@ -106,7 +109,9 @@ async fn main() -> anyhow::Result<()> {
let filter = tracing_subscriber::EnvFilter::try_from_default_env()?;
tracing_subscriber::FmtSubscriber::builder().with_env_filter(filter).finish().try_init()?;

let handle = run_server();
let metrics = Metrics::default();

let handle = run_server(metrics.clone());
tokio::spawn(handle.stopped());

{
Expand All @@ -117,6 +122,14 @@ async fn main() -> anyhow::Result<()> {
tracing::info!("response: {x}");
}

{
let client = WsClientBuilder::default().build("ws://127.0.0.1:9944").await.unwrap();

// Fails because the authorization header is missing.
let x = client.trusted_call().await.unwrap_err();
tracing::info!("response: {x}");
}

{
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, HeaderValue::from_static("don't care in this example"));
Expand All @@ -127,10 +140,12 @@ async fn main() -> anyhow::Result<()> {
tracing::info!("response: {x}");
}

tracing::info!("{:?}", metrics);

Ok(())
}

fn run_server() -> ServerHandle {
fn run_server(metrics: Metrics) -> ServerHandle {
use hyper::service::{make_service_fn, service_fn};

let addr = SocketAddr::from(([127, 0, 0, 1], 9944));
Expand Down Expand Up @@ -159,7 +174,7 @@ fn run_server() -> ServerHandle {
let per_conn = PerConnection {
methods: ().into_rpc().into(),
stop_handle: stop_handle.clone(),
metrics: Metrics::default(),
metrics,
svc_builder: jsonrpsee::server::Server::builder()
.set_http_middleware(tower::ServiceBuilder::new().layer(CorsLayer::permissive()))
.max_connections(33)
Expand All @@ -185,15 +200,40 @@ fn run_server() -> ServerHandle {

let mut svc = svc_builder.set_rpc_middleware(rpc_middleware).build(methods, stop_handle);

async move {
// You can't determine whether the websocket upgrade handshake failed or not here.
let rp = svc.call(req).await;
if is_websocket {
metrics.ws_connections.fetch_add(1, Ordering::Relaxed);
} else {
metrics.http_connections.fetch_add(1, Ordering::Relaxed);
if is_websocket {
// Utilize the session close future to know when the actual WebSocket
// session was closed.
let session_close = svc.on_session_closed();

// A little bit weird API but the response to HTTP request must be returned below
// and we spawn a task to register when the session is closed.
tokio::spawn(async move {
session_close.await;
tracing::info!("Closed WebSocket connection");
metrics.closed_ws_connections.fetch_add(1, Ordering::Relaxed);
});

async move {
tracing::info!("Opened WebSocket connection");
metrics.opened_ws_connections.fetch_add(1, Ordering::Relaxed);
svc.call(req).await
}
.boxed()
} else {
// HTTP.
async move {
tracing::info!("Opened HTTP connection");
metrics.http_calls.fetch_add(1, Ordering::Relaxed);
let rp = svc.call(req).await;

if rp.is_ok() {
metrics.success_http_calls.fetch_add(1, Ordering::Relaxed);
}

tracing::info!("Closed HTTP connection");
rp
}
rp
.boxed()
}
}))
}
Expand Down
2 changes: 1 addition & 1 deletion server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ serde_json = { version = "1", features = ["raw_value"] }
soketto = { version = "0.7.1", features = ["http"] }
tokio = { version = "1.16", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.7", features = ["compat"] }
tokio-stream = "0.1.7"
tokio-stream = { version = "0.1.7", features = ["sync"] }
hyper = { version = "0.14", features = ["server", "http1", "http2"] }
tower = "0.4.13"
thiserror = "1"
Expand Down
40 changes: 39 additions & 1 deletion server/src/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures_util::{Stream, StreamExt};
use futures_util::{Future, Stream, StreamExt};
use pin_project::pin_project;
use tokio::sync::{watch, OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::time::Interval;
use tokio_stream::wrappers::BroadcastStream;

/// Create channel to determine whether
/// the server shall continue to run or not.
Expand Down Expand Up @@ -157,3 +158,40 @@ impl Stream for IntervalStream {
}
}
}

/// Cloneable sender half used to signal that a connection's session has closed.
///
/// Each call to `closed()` subscribes a new `SessionClosedFuture`; `close()`
/// broadcasts to all of them (see the `impl` below).
#[derive(Debug, Clone)]
pub(crate) struct SessionClose(tokio::sync::broadcast::Sender<()>);

impl SessionClose {
pub(crate) fn close(self) {
let _ = self.0.send(());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Could this be implemented also on Drop? And if we already called close() we'd do nothing on drop, otherwise, we'll call self.0.send()?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I skipped it because I'm scared of that as the tower stuff requires Clone and I'm not sure whether something is dropped at some point.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, thanks!

}

pub(crate) fn closed(&self) -> SessionClosedFuture {
SessionClosedFuture(BroadcastStream::new(self.0.subscribe()))
}
}

/// A future that resolves when the connection has been closed.
///
/// Backed by a broadcast receiver wrapped in a `BroadcastStream`; it resolves
/// as soon as the stream yields any item (message or lag error).
#[derive(Debug)]
pub struct SessionClosedFuture(BroadcastStream<()>);

impl Future for SessionClosedFuture {
	type Output = ();

	fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
		// Only a single message is ever sent on the channel, so any yielded
		// item — including a broadcast "lagged" error — means the session is
		// closed; the payload itself is irrelevant and is discarded.
		self.0.poll_next_unpin(cx).map(|_| ())
	}
}

pub(crate) fn session_close() -> (SessionClose, SessionClosedFuture) {
// The `SessionClosedFuture` resolves once a single message has been received;
// at most one message is ever sent, after which the channel is closed.
let (tx, rx) = tokio::sync::broadcast::channel(1);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tokio's API is a bit strange I feel here; you can only create a pair, but then you can also call subscribe() on a tx to get another rx!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yepp, I think the implementation is quite complicated/clever as the Receiver is not clone.

I think it just clone the message(s) to other receivers which is probably quite nice to avoid having a separate state for each receiver

(SessionClose(tx), SessionClosedFuture(BroadcastStream::new(rx)))
}
33 changes: 28 additions & 5 deletions server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;

use crate::future::{ConnectionGuard, ServerHandle, StopHandle};
use crate::future::{session_close, ConnectionGuard, ServerHandle, SessionClose, SessionClosedFuture, StopHandle};
use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT};
use crate::transport::ws::BackgroundTaskParams;
use crate::transport::{http, ws};
Expand Down Expand Up @@ -501,6 +501,7 @@ impl<RpcMiddleware, HttpMiddleware> TowerServiceBuilder<RpcMiddleware, HttpMiddl
conn_guard: self.conn_guard,
server_cfg: self.server_cfg,
},
on_session_close: None,
};

TowerService { rpc_middleware, http_middleware: self.http_middleware }
Expand Down Expand Up @@ -617,18 +618,18 @@ impl<HttpMiddleware, RpcMiddleware> Builder<HttpMiddleware, RpcMiddleware> {
/// impl<'a, S> RpcServiceT<'a> for MyMiddleware<S>
/// where S: RpcServiceT<'a> + Send + Sync + Clone + 'static,
/// {
/// type Future = BoxFuture<'a, MethodResponse>;
///
/// type Future = BoxFuture<'a, MethodResponse>;
///
/// fn call(&self, req: Request<'a>) -> Self::Future {
/// tracing::info!("MyMiddleware processed call {}", req.method);
/// let count = self.count.clone();
/// let service = self.service.clone();
/// let service = self.service.clone();
///
/// Box::pin(async move {
/// let rp = service.call(req).await;
/// // Modify the state.
/// count.fetch_add(1, Ordering::Relaxed);
/// rp
/// rp
/// })
/// }
/// }
Expand Down Expand Up @@ -941,6 +942,24 @@ pub struct TowerService<RpcMiddleware, HttpMiddleware> {
http_middleware: tower::ServiceBuilder<HttpMiddleware>,
}

impl<RpcMiddleware, HttpMiddleware> TowerService<RpcMiddleware, HttpMiddleware> {
/// A future that returns when the connection has been closed.
///
/// This method must be called before every [`TowerService::call`]
/// because the `SessionClosedFuture` may already have been consumed
/// or not used.
pub fn on_session_closed(&mut self) -> SessionClosedFuture {
if let Some(n) = self.rpc_middleware.on_session_close.as_mut() {
// If it's called more than once, another listener is created.
n.closed()
} else {
let (session_close, fut) = session_close();
self.rpc_middleware.on_session_close = Some(session_close);
Copy link
Collaborator

@jsdw jsdw Feb 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any likelihood of a race where the session is closing already or something, and only then you cann on_session_closed and then get back a future that's never called or something?

I don't think it really matters though because why would you subscribe to this so late on

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some extra docs on it, it's possible.

Also it's a little bit weird that TowerService::call can be used several times, in practice I don't think that's the case but a footgun.

That on_session_close on works the subsequent TowerService::call as it does Option::take

fut
}
}
}

impl<RpcMiddleware, HttpMiddleware> hyper::service::Service<hyper::Request<hyper::Body>>
for TowerService<RpcMiddleware, HttpMiddleware>
where
Expand Down Expand Up @@ -979,6 +998,7 @@ where
pub struct TowerServiceNoHttp<L> {
inner: ServiceData,
rpc_middleware: RpcServiceBuilder<L>,
on_session_close: Option<SessionClose>,
}

impl<RpcMiddleware> hyper::service::Service<hyper::Request<hyper::Body>> for TowerServiceNoHttp<RpcMiddleware>
Expand All @@ -1004,6 +1024,7 @@ where
let conn_guard = &self.inner.conn_guard;
let stop_handle = self.inner.stop_handle.clone();
let conn_id = self.inner.conn_id;
let on_session_close = self.on_session_close.take();

tracing::trace!(target: LOG_TARGET, "{:?}", request);

Expand Down Expand Up @@ -1076,6 +1097,7 @@ where
sink,
rx,
pending_calls_completed,
on_session_close,
};

ws::background_task(params).await;
Expand Down Expand Up @@ -1176,6 +1198,7 @@ fn process_connection<'a, RpcMiddleware, HttpMiddleware, U>(
conn_guard: conn_guard.clone(),
},
rpc_middleware,
on_session_close: None,
};

let service = http_middleware.service(tower_service);
Expand Down
71 changes: 69 additions & 2 deletions server/src/tests/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
use std::fmt;
use std::error::Error as StdError;
use std::net::SocketAddr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::{fmt, sync::atomic::AtomicUsize};

use crate::{RpcModule, ServerBuilder, ServerHandle};
use crate::{stop_channel, RpcModule, Server, ServerBuilder, ServerHandle};

use futures_util::FutureExt;
use hyper::server::conn::AddrStream;
use jsonrpsee_core::server::Methods;
use jsonrpsee_core::{DeserializeOwned, RpcResult, StringError};
use jsonrpsee_test_utils::TimeoutFutureExt;
use jsonrpsee_types::{error::ErrorCode, ErrorObject, ErrorObjectOwned, Response, ResponseSuccess};
use tower::Service;
use tracing_subscriber::{EnvFilter, FmtSubscriber};

pub(crate) struct TestContext;
Expand Down Expand Up @@ -194,3 +201,63 @@ impl From<MyAppError> for ErrorObjectOwned {
fn invalid_params() -> ErrorObjectOwned {
ErrorCode::InvalidParams.into()
}

/// Shared counters used by tests to observe WebSocket session lifecycle events.
#[derive(Debug, Clone, Default)]
pub(crate) struct Metrics {
	/// Incremented once per accepted WebSocket upgrade.
	pub(crate) ws_sessions_opened: Arc<AtomicUsize>,
	/// Incremented once the session-close futures for a connection resolve.
	pub(crate) ws_sessions_closed: Arc<AtomicUsize>,
}

/// Spawns a hyper-based jsonrpsee server on an OS-assigned local port and
/// returns its bound address.
///
/// WebSocket upgrade requests register two `on_session_closed` futures and a
/// task that bumps the `opened`/`closed` counters in `metrics`; plain HTTP
/// requests are served without touching the counters. The server shuts down
/// gracefully when the associated `ServerHandle` is dropped/stopped.
pub(crate) fn ws_server_with_stats(metrics: Metrics) -> SocketAddr {
	use hyper::service::{make_service_fn, service_fn};

	// Port 0: let the OS pick a free port; the real address is read back below.
	let addr = SocketAddr::from(([127, 0, 0, 1], 0));
	let (stop_handle, server_handle) = stop_channel();
	let stop_handle2 = stop_handle.clone();

	// A MakeService producing one service per incoming connection.
	let make_service = make_service_fn(move |_conn: &AddrStream| {
		let stop_handle = stop_handle2.clone();
		let metrics = metrics.clone();

		async move {
			Ok::<_, Box<dyn StdError + Send + Sync>>(service_fn(move |req| {
				let is_websocket = crate::ws::is_upgrade_request(&req);
				let metrics = metrics.clone();
				let stop_handle = stop_handle.clone();

				// Empty `Methods`: this server exists only to exercise the
				// session-close machinery, not to serve RPC calls.
				let mut svc =
					Server::builder().max_connections(33).to_service_builder().build(Methods::new(), stop_handle);

				if is_websocket {
					// Calling `on_session_closed` twice verifies that every
					// subscriber is notified, not just the first one.
					let session_close1 = svc.on_session_closed();
					let session_close2 = svc.on_session_closed();

					// Count the session as closed only after both futures resolve.
					tokio::spawn(async move {
						metrics.ws_sessions_opened.fetch_add(1, Ordering::SeqCst);
						tokio::join!(session_close2, session_close1);
						metrics.ws_sessions_closed.fetch_add(1, Ordering::SeqCst);
					});

					async move { svc.call(req).await }.boxed()
				} else {
					// HTTP: no session tracking.
					async move { svc.call(req).await }.boxed()
				}
			}))
		}
	});

	let server = hyper::Server::bind(&addr).serve(make_service);

	// Actual bound address (the OS-assigned port), returned to the caller.
	let addr = server.local_addr();

	tokio::spawn(async move {
		let graceful = server.with_graceful_shutdown(async move { stop_handle.shutdown().await });
		graceful.await.unwrap();
		// Keep the server handle alive until shutdown completes.
		drop(server_handle)
	});

	addr
}
Loading
Loading