diff --git a/core/src/server/method_response.rs b/core/src/server/method_response.rs index c4bef1222e..f9e134504c 100644 --- a/core/src/server/method_response.rs +++ b/core/src/server/method_response.rs @@ -144,19 +144,6 @@ impl MethodResponse { /// If the serialization of `result` exceeds `max_response_size` then /// the response is changed to an JSON-RPC error object. pub fn response(id: Id, rp: ResponsePayload, max_response_size: usize) -> Self - where - T: Serialize + Clone, - { - Self::response_with_extensions(id, rp, max_response_size, Extensions::new()) - } - - /// Similar to [`MethodResponse::response`] but with extensions. - pub fn response_with_extensions( - id: Id, - rp: ResponsePayload, - max_response_size: usize, - extensions: Extensions, - ) -> Self where T: Serialize + Clone, { @@ -209,7 +196,7 @@ impl MethodResponse { success_or_error: MethodResponseResult::Failed(err.code()), kind, on_close: rp.on_exit, - extensions, + extensions: Extensions::new(), } } } @@ -224,8 +211,8 @@ impl MethodResponse { rp } - /// Similar to [`MethodResponse::error`] but with extensions. - pub fn error_with_extensions<'a>(id: Id, err: impl Into>, extensions: Extensions) -> Self { + /// Create a [`MethodResponse`] from a JSON-RPC error. + pub fn error<'a>(id: Id, err: impl Into>) -> Self { let err: ErrorObject = err.into(); let err_code = err.code(); let err = InnerResponsePayload::<()>::error_borrowed(err); @@ -235,15 +222,10 @@ impl MethodResponse { success_or_error: MethodResponseResult::Failed(err_code), kind: ResponseKind::MethodCall, on_close: None, - extensions, + extensions: Extensions::new(), } } - /// Create a [`MethodResponse`] from a JSON-RPC error. - pub fn error<'a>(id: Id, err: impl Into>) -> Self { - Self::error_with_extensions(id, err, Extensions::new()) - } - /// Returns a reference to the associated extensions. pub fn extensions(&self) -> &Extensions { &self.extensions @@ -253,6 +235,11 @@ impl MethodResponse { pub fn extensions_mut(&mut self) -> &mut Extensions { &mut self.extensions } + + /// Consumes the method response and returns a new one with the given extensions. + pub fn with_extensions(self, extensions: Extensions) -> Self { + Self { extensions, ..self } + } } /// Represent the outcome of a method call success or failed. diff --git a/core/src/server/rpc_module.rs b/core/src/server/rpc_module.rs index 541535e12c..d6d0fc643f 100644 --- a/core/src/server/rpc_module.rs +++ b/core/src/server/rpc_module.rs @@ -546,7 +546,7 @@ impl RpcModule { method_name, MethodCallback::Sync(Arc::new(move |id, params, max_response_size, extensions| { let rp = callback(params, &*ctx, &extensions).into_response(); - MethodResponse::response_with_extensions(id, rp, max_response_size, extensions) + MethodResponse::response(id, rp, max_response_size).with_extensions(extensions) })), ) } @@ -580,9 +580,11 @@ impl RpcModule { let ctx = ctx.clone(); let callback = callback.clone(); + // NOTE: the extensions can't be mutated at this point so + // it's safe to clone it. let future = async move { let rp = callback(params, ctx, extensions.clone()).await.into_response(); - MethodResponse::response_with_extensions(id, rp, max_response_size, extensions) + MethodResponse::response(id, rp, max_response_size).with_extensions(extensions) }; future.boxed() })), @@ -600,7 +602,7 @@ impl RpcModule { where Context: Send + Sync + 'static, R: IntoResponse + 'static, - F: Fn(Params, Arc, &Extensions) -> R + Clone + Send + Sync + 'static, + F: Fn(Params, Arc, Extensions) -> R + Clone + Send + Sync + 'static, { let ctx = self.ctx.clone(); let callback = self.methods.verify_and_insert( @@ -609,20 +611,20 @@ impl RpcModule { let ctx = ctx.clone(); let callback = callback.clone(); + // NOTE: the extensions can't be mutated at this point so + // it's safe to clone it. let extensions2 = extensions.clone(); + tokio::task::spawn_blocking(move || { - let rp = callback(params, ctx, &extensions2).into_response(); - MethodResponse::response_with_extensions(id, rp, max_response_size, extensions2) + let rp = callback(params, ctx, extensions2.clone()).into_response(); + MethodResponse::response(id, rp, max_response_size).with_extensions(extensions2) }) .map(|result| match result { Ok(r) => r, Err(err) => { tracing::error!(target: LOG_TARGET, "Join error for blocking RPC method: {:?}", err); - MethodResponse::error_with_extensions( - Id::Null, - ErrorObject::from(ErrorCode::InternalError), - extensions, - ) + MethodResponse::error(Id::Null, ErrorObject::from(ErrorCode::InternalError)) + .with_extensions(extensions) } }) .boxed() @@ -774,6 +776,9 @@ impl RpcModule { // definition and not the as same when the subscription call has been completed. // // This runs until the subscription callback has completed. + // + // NOTE: the extensions can't be mutated at this point so + // it's safe to clone it. let sub_fut = callback(params.into_owned(), sink, ctx.clone(), extensions.clone()); tokio::spawn(async move { @@ -800,18 +805,19 @@ impl RpcModule { let id = id.clone().into_owned(); Box::pin(async move { - match rx.await { - Ok(mut rp) => { + let rp = match rx.await { + Ok(rp) => { // If the subscription was accepted then send a message // to subscription task otherwise rely on the drop impl. if rp.is_success() { let _ = accepted_tx.send(()); } - *rp.extensions_mut() = extensions; rp } - Err(_) => MethodResponse::error_with_extensions(id, ErrorCode::InternalError, extensions), - } + Err(_) => MethodResponse::error(id, ErrorCode::InternalError), + }; + + rp.with_extensions(extensions) }) })), )? @@ -905,13 +911,12 @@ impl RpcModule { let id = id.clone().into_owned(); Box::pin(async move { - match rx.await { - Ok(mut rp) => { - *rp.extensions_mut() = extensions; - rp - } - Err(_) => MethodResponse::error_with_extensions(id, ErrorCode::InternalError, extensions), - } + let rp = match rx.await { + Ok(rp) => rp, + Err(_) => MethodResponse::error(id, ErrorCode::InternalError), + }; + + rp.with_extensions(extensions) }) })), )? @@ -953,12 +958,8 @@ impl RpcModule { id ); - return MethodResponse::response_with_extensions( - id, - ResponsePayload::success(false), - max_response_size, - extensions, - ); + return MethodResponse::response(id, ResponsePayload::success(false), max_response_size) + .with_extensions(extensions); } }; diff --git a/examples/examples/rpc_middleware.rs b/examples/examples/rpc_middleware.rs index e35f49a52e..4783257014 100644 --- a/examples/examples/rpc_middleware.rs +++ b/examples/examples/rpc_middleware.rs @@ -43,7 +43,7 @@ use std::sync::Arc; use futures::future::BoxFuture; use futures::FutureExt; -use jsonrpsee::core::{async_trait, client::ClientT}; +use jsonrpsee::core::client::ClientT; use jsonrpsee::rpc_params; use jsonrpsee::server::middleware::rpc::{RpcServiceBuilder, RpcServiceT}; use jsonrpsee::server::{MethodResponse, RpcModule, Server}; @@ -85,7 +85,6 @@ pub struct GlobalCalls { count: Arc, } -#[async_trait] impl<'a, S> RpcServiceT<'a> for GlobalCalls where S: RpcServiceT<'a> + Send + Sync + Clone + 'static, diff --git a/examples/examples/server_with_connection_details.rs b/examples/examples/server_with_connection_details.rs index 966e9f9d4f..3e339a5b4c 100644 --- a/examples/examples/server_with_connection_details.rs +++ b/examples/examples/server_with_connection_details.rs @@ -29,6 +29,7 @@ use std::net::SocketAddr; use jsonrpsee::core::async_trait; use jsonrpsee::core::SubscriptionResult; use jsonrpsee::proc_macros::rpc; +use jsonrpsee::server::middleware::rpc::RpcServiceT; use jsonrpsee::server::{PendingSubscriptionSink, SubscriptionMessage}; use jsonrpsee::types::{ErrorObject, ErrorObjectOwned}; use jsonrpsee::ws_client::WsClientBuilder; @@ -43,6 +44,22 @@ pub trait Rpc { #[subscription(name = "subscribeConnectionId", item = usize, with_extensions)] async fn sub(&self) -> SubscriptionResult; + + #[subscription(name = "subscribeSyncConnectionId", item = usize, with_extensions)] + fn sub2(&self) -> SubscriptionResult; +} + +struct LoggingMiddleware(S); + +impl<'a, S: RpcServiceT<'a>> RpcServiceT<'a> for LoggingMiddleware { + type Future = S::Future; + + fn call(&self, request: jsonrpsee::types::Request<'a>) -> Self::Future { + tracing::info!("Received request: {:?}", request); + assert!(request.extensions().get::().is_some()); + + self.0.call(request) + } } pub struct RpcServerImpl; @@ -67,6 +84,16 @@ impl RpcServer for RpcServerImpl { sink.send(SubscriptionMessage::from_json(&conn_id).unwrap()).await?; Ok(()) } + + fn sub2(&self, pending: PendingSubscriptionSink, ext: &Extensions) -> SubscriptionResult { + let conn_id = ext.get::().cloned().unwrap(); + + tokio::spawn(async move { + let sink = pending.accept().await.unwrap(); + sink.send(SubscriptionMessage::from_json(&conn_id).unwrap()).await.unwrap(); + }); + Ok(()) + } } #[tokio::main] @@ -100,7 +127,9 @@ async fn main() -> anyhow::Result<()> { } async fn run_server() -> anyhow::Result { - let server = jsonrpsee::server::Server::builder().build("127.0.0.1:0").await?; + let rpc_middleware = jsonrpsee::server::middleware::rpc::RpcServiceBuilder::new().layer_fn(LoggingMiddleware); + + let server = jsonrpsee::server::Server::builder().set_rpc_middleware(rpc_middleware).build("127.0.0.1:0").await?; let addr = server.local_addr()?; let handle = server.start(RpcServerImpl.into_rpc()); diff --git a/server/src/middleware/rpc/layer/rpc_service.rs b/server/src/middleware/rpc/layer/rpc_service.rs index 3b69bbb84a..24d049477c 100644 --- a/server/src/middleware/rpc/layer/rpc_service.rs +++ b/server/src/middleware/rpc/layer/rpc_service.rs @@ -83,15 +83,13 @@ impl<'a> RpcServiceT<'a> for RpcService { let conn_id = self.conn_id; let max_response_body_size = self.max_response_body_size; - let Request { id, method, params, mut extensions, .. } = req; - extensions.insert(conn_id); - + let Request { id, method, params, extensions, .. } = req; let params = jsonrpsee_types::Params::new(params.as_ref().map(|p| serde_json::value::RawValue::get(p))); match self.methods.method_with_name(&method) { None => { - let mut rp = MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound)); - *rp.extensions_mut() = extensions; + let rp = + MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound)).with_extensions(extensions); ResponseFuture::ready(rp) } Some((_name, method)) => match method { @@ -115,7 +113,8 @@ impl<'a> RpcServiceT<'a> for RpcService { } = self.cfg.clone() else { tracing::warn!("Subscriptions not supported"); - let rp = MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError)); + let rp = MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError)) + .with_extensions(extensions); return ResponseFuture::ready(rp); }; @@ -127,7 +126,8 @@ impl<'a> RpcServiceT<'a> for RpcService { ResponseFuture::future(fut) } else { let max = bounded_subscriptions.max(); - let rp = MethodResponse::error(id, reject_too_many_subscriptions(max)); + let rp = + MethodResponse::error(id, reject_too_many_subscriptions(max)).with_extensions(extensions); ResponseFuture::ready(rp) } } @@ -136,7 +136,8 @@ impl<'a> RpcServiceT<'a> for RpcService { let RpcServiceCfg::CallsAndSubscriptions { .. } = self.cfg else { tracing::warn!("Subscriptions not supported"); - let rp = MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError)); + let rp = MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError)) + .with_extensions(extensions); return ResponseFuture::ready(rp); }; diff --git a/server/src/server.rs b/server/src/server.rs index 1636bb283d..02aa6ecb4e 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -37,7 +37,8 @@ use crate::future::{session_close, ConnectionGuard, ServerHandle, SessionClose, use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}; use crate::transport::ws::BackgroundTaskParams; use crate::transport::{http, ws}; -use crate::{HttpBody, HttpRequest, HttpResponse, LOG_TARGET}; +use crate::utils::deserialize; +use crate::{Extensions, HttpBody, HttpRequest, HttpResponse, LOG_TARGET}; use futures_util::future::{self, Either, FutureExt}; use futures_util::io::{BufReader, BufWriter}; @@ -46,14 +47,16 @@ use hyper::body::Bytes; use hyper_util::rt::{TokioExecutor, TokioIo}; use jsonrpsee_core::id_providers::RandomIntegerIdProvider; use jsonrpsee_core::server::helpers::prepare_error; -use jsonrpsee_core::server::{BatchResponseBuilder, BoundedSubscriptions, MethodResponse, MethodSink, Methods}; +use jsonrpsee_core::server::{ + BatchResponseBuilder, BoundedSubscriptions, ConnectionId, MethodResponse, MethodSink, Methods, +}; use jsonrpsee_core::traits::IdProvider; use jsonrpsee_core::{BoxError, JsonRawValue, TEN_MB_SIZE_BYTES}; use jsonrpsee_types::error::{ reject_too_big_batch_request, ErrorCode, BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, }; -use jsonrpsee_types::{ErrorObject, Id, InvalidRequest, Notification, Request}; +use jsonrpsee_types::{ErrorObject, Id, InvalidRequest, Notification}; use soketto::handshake::http::is_upgrade_request; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio::sync::{mpsc, watch, OwnedSemaphorePermit}; @@ -1053,7 +1056,7 @@ where } fn call(&mut self, request: HttpRequest) -> Self::Future { - let request = request.map(HttpBody::new); + let mut request = request.map(HttpBody::new); let conn_guard = &self.inner.conn_guard; let stop_handle = self.inner.stop_handle.clone(); @@ -1072,6 +1075,8 @@ where let curr_conns = max_conns - conn_guard.available_connections(); tracing::debug!(target: LOG_TARGET, "Accepting new connection {}/{}", curr_conns, max_conns); + request.extensions_mut().insert::(conn.conn_id.into()); + let is_upgrade_request = is_upgrade_request(&request); if self.inner.server_cfg.enable_ws && is_upgrade_request { @@ -1109,6 +1114,8 @@ where tokio::spawn( async move { + let extensions = request.extensions().clone(); + let upgraded = match hyper::upgrade::on(request).await { Ok(u) => u, Err(e) => { @@ -1134,6 +1141,7 @@ where rx, pending_calls_completed, on_session_close, + extensions, }; ws::background_task(params).await; @@ -1288,13 +1296,14 @@ pub(crate) async fn handle_rpc_call( batch_config: BatchRequestConfig, max_response_size: u32, rpc_service: &S, + extensions: Extensions, ) -> Option where for<'a> S: RpcServiceT<'a> + Send, { // Single request or notification if is_single { - if let Ok(req) = serde_json::from_slice(body) { + if let Ok(req) = deserialize::from_slice_with_extensions(body, extensions) { Some(rpc_service.call(req).await) } else if let Ok(_notif) = serde_json::from_slice::(body) { None @@ -1326,7 +1335,7 @@ where let mut batch_response = BatchResponseBuilder::new_with_limit(max_response_size as usize); for call in batch { - if let Ok(req) = serde_json::from_str::(call.get()) { + if let Ok(req) = deserialize::from_str_with_extensions(call.get(), extensions.clone()) { let rp = rpc_service.call(req).await; if let Err(too_large) = batch_response.append(&rp) { diff --git a/server/src/transport/http.rs b/server/src/transport/http.rs index 022c4dce82..2a730101cd 100644 --- a/server/src/transport/http.rs +++ b/server/src/transport/http.rs @@ -95,7 +95,8 @@ where } }; - let rp = handle_rpc_call(&body, is_single, batch_config, max_response_size, &rpc_service).await; + let rp = handle_rpc_call(&body, is_single, batch_config, max_response_size, &rpc_service, parts.extensions) + .await; // If the response is empty it means that it was a notification or empty batch. // For HTTP these are just ACK:ed with a empty body. diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index bcf3ee192a..9fc24cdccf 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -57,6 +57,7 @@ pub(crate) struct BackgroundTaskParams { pub(crate) rx: mpsc::Receiver, pub(crate) pending_calls_completed: mpsc::Receiver<()>, pub(crate) on_session_close: Option, + pub(crate) extensions: http::Extensions, } pub(crate) async fn background_task(params: BackgroundTaskParams) @@ -73,6 +74,7 @@ where rx, pending_calls_completed, mut on_session_close, + extensions, } = params; let ServerConfig { ping_config, batch_requests_config, max_request_body_size, max_response_body_size, .. } = server_cfg; @@ -141,6 +143,7 @@ where let rpc_service = rpc_service.clone(); let sink = sink.clone(); + let extensions = extensions.clone(); tokio::spawn(async move { let first_non_whitespace = data.iter().enumerate().take(128).find(|(_, byte)| !byte.is_ascii_whitespace()); @@ -154,9 +157,15 @@ where } }; - if let Some(rp) = - handle_rpc_call(&data[idx..], is_single, batch_requests_config, max_response_body_size, &*rpc_service) - .await + if let Some(rp) = handle_rpc_call( + &data[idx..], + is_single, + batch_requests_config, + max_response_body_size, + &*rpc_service, + extensions, + ) + .await { if !rp.is_subscription() { let is_success = rp.is_success(); @@ -426,6 +435,8 @@ where match server.receive_request(&req) { Ok(response) => { + let extensions = req.extensions().clone(); + let upgraded = match hyper::upgrade::on(req).await { Ok(u) => u, Err(e) => { @@ -477,6 +488,7 @@ where rx, pending_calls_completed, on_session_close: None, + extensions, }; background_task(params).await; diff --git a/server/src/utils.rs b/server/src/utils.rs index 93214e5a70..0337c65d65 100644 --- a/server/src/utils.rs +++ b/server/src/utils.rs @@ -78,3 +78,26 @@ where self.project().future.poll(cx) } } + +/// Helpers to deserialize a request with extensions. +pub(crate) mod deserialize { + /// Helper to deserialize a request with extensions. + pub(crate) fn from_slice_with_extensions( + data: &[u8], + extensions: http::Extensions, + ) -> Result { + let mut req: jsonrpsee_types::Request = serde_json::from_slice(data)?; + *req.extensions_mut() = extensions; + Ok(req) + } + + /// Helper to deserialize a request with extensions. + pub(crate) fn from_str_with_extensions( + data: &str, + extensions: http::Extensions, + ) -> Result { + let mut req: jsonrpsee_types::Request = serde_json::from_str(data)?; + *req.extensions_mut() = extensions; + Ok(req) + } +} diff --git a/tests/tests/metrics.rs b/tests/tests/metrics.rs index 2a308f3578..c8cef4d626 100644 --- a/tests/tests/metrics.rs +++ b/tests/tests/metrics.rs @@ -37,7 +37,7 @@ use std::time::Duration; use futures::future::BoxFuture; use futures::FutureExt; use helpers::init_logger; -use jsonrpsee::core::{async_trait, client::ClientT, ClientError}; +use jsonrpsee::core::{client::ClientT, ClientError}; use jsonrpsee::http_client::HttpClientBuilder; use jsonrpsee::proc_macros::rpc; use jsonrpsee::server::middleware::rpc::{RpcServiceBuilder, RpcServiceT}; @@ -62,7 +62,6 @@ pub struct CounterMiddleware { counter: Arc>, } -#[async_trait] impl<'a, S> RpcServiceT<'a> for CounterMiddleware where S: RpcServiceT<'a> + Send + Sync + Clone + 'static,