Skip to content

Commit

Permalink
server: pass extensions from http layer (#1389)
Browse files Browse the repository at this point in the history
* server: pass `extensions` from http layer

* fix nits

* remove needless clone

* fix nits

* remove needless async_trait
  • Loading branch information
niklasad1 authored Jun 5, 2024
1 parent 236a561 commit 038a77f
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 73 deletions.
31 changes: 9 additions & 22 deletions core/src/server/method_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,19 +144,6 @@ impl MethodResponse {
/// If the serialization of `result` exceeds `max_response_size` then
/// the response is changed to an JSON-RPC error object.
pub fn response<T>(id: Id, rp: ResponsePayload<T>, max_response_size: usize) -> Self
where
T: Serialize + Clone,
{
Self::response_with_extensions(id, rp, max_response_size, Extensions::new())
}

/// Similar to [`MethodResponse::response`] but with extensions.
pub fn response_with_extensions<T>(
id: Id,
rp: ResponsePayload<T>,
max_response_size: usize,
extensions: Extensions,
) -> Self
where
T: Serialize + Clone,
{
Expand Down Expand Up @@ -209,7 +196,7 @@ impl MethodResponse {
success_or_error: MethodResponseResult::Failed(err.code()),
kind,
on_close: rp.on_exit,
extensions,
extensions: Extensions::new(),
}
}
}
Expand All @@ -224,8 +211,8 @@ impl MethodResponse {
rp
}

/// Similar to [`MethodResponse::error`] but with extensions.
pub fn error_with_extensions<'a>(id: Id, err: impl Into<ErrorObject<'a>>, extensions: Extensions) -> Self {
/// Create a [`MethodResponse`] from a JSON-RPC error.
pub fn error<'a>(id: Id, err: impl Into<ErrorObject<'a>>) -> Self {
let err: ErrorObject = err.into();
let err_code = err.code();
let err = InnerResponsePayload::<()>::error_borrowed(err);
Expand All @@ -235,15 +222,10 @@ impl MethodResponse {
success_or_error: MethodResponseResult::Failed(err_code),
kind: ResponseKind::MethodCall,
on_close: None,
extensions,
extensions: Extensions::new(),
}
}

/// Create a [`MethodResponse`] from a JSON-RPC error.
pub fn error<'a>(id: Id, err: impl Into<ErrorObject<'a>>) -> Self {
Self::error_with_extensions(id, err, Extensions::new())
}

/// Returns a reference to the associated extensions.
pub fn extensions(&self) -> &Extensions {
&self.extensions
Expand All @@ -253,6 +235,11 @@ impl MethodResponse {
pub fn extensions_mut(&mut self) -> &mut Extensions {
&mut self.extensions
}

/// Consumes the method response and returns a new one with the given extensions.
pub fn with_extensions(self, extensions: Extensions) -> Self {
Self { extensions, ..self }
}
}

/// Represent the outcome of a method call success or failed.
Expand Down
57 changes: 29 additions & 28 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
method_name,
MethodCallback::Sync(Arc::new(move |id, params, max_response_size, extensions| {
let rp = callback(params, &*ctx, &extensions).into_response();
MethodResponse::response_with_extensions(id, rp, max_response_size, extensions)
MethodResponse::response(id, rp, max_response_size).with_extensions(extensions)
})),
)
}
Expand Down Expand Up @@ -580,9 +580,11 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let ctx = ctx.clone();
let callback = callback.clone();

// NOTE: the extensions can't be mutated at this point so
// it's safe to clone it.
let future = async move {
let rp = callback(params, ctx, extensions.clone()).await.into_response();
MethodResponse::response_with_extensions(id, rp, max_response_size, extensions)
MethodResponse::response(id, rp, max_response_size).with_extensions(extensions)
};
future.boxed()
})),
Expand All @@ -600,7 +602,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
where
Context: Send + Sync + 'static,
R: IntoResponse + 'static,
F: Fn(Params, Arc<Context>, &Extensions) -> R + Clone + Send + Sync + 'static,
F: Fn(Params, Arc<Context>, Extensions) -> R + Clone + Send + Sync + 'static,
{
let ctx = self.ctx.clone();
let callback = self.methods.verify_and_insert(
Expand All @@ -609,20 +611,20 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let ctx = ctx.clone();
let callback = callback.clone();

// NOTE: the extensions can't be mutated at this point so
// it's safe to clone it.
let extensions2 = extensions.clone();

tokio::task::spawn_blocking(move || {
let rp = callback(params, ctx, &extensions2).into_response();
MethodResponse::response_with_extensions(id, rp, max_response_size, extensions2)
let rp = callback(params, ctx, extensions2.clone()).into_response();
MethodResponse::response(id, rp, max_response_size).with_extensions(extensions2)
})
.map(|result| match result {
Ok(r) => r,
Err(err) => {
tracing::error!(target: LOG_TARGET, "Join error for blocking RPC method: {:?}", err);
MethodResponse::error_with_extensions(
Id::Null,
ErrorObject::from(ErrorCode::InternalError),
extensions,
)
MethodResponse::error(Id::Null, ErrorObject::from(ErrorCode::InternalError))
.with_extensions(extensions)
}
})
.boxed()
Expand Down Expand Up @@ -774,6 +776,9 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
// definition and not the as same when the subscription call has been completed.
//
// This runs until the subscription callback has completed.
//
// NOTE: the extensions can't be mutated at this point so
// it's safe to clone it.
let sub_fut = callback(params.into_owned(), sink, ctx.clone(), extensions.clone());

tokio::spawn(async move {
Expand All @@ -800,18 +805,19 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let id = id.clone().into_owned();

Box::pin(async move {
match rx.await {
Ok(mut rp) => {
let rp = match rx.await {
Ok(rp) => {
// If the subscription was accepted then send a message
// to subscription task otherwise rely on the drop impl.
if rp.is_success() {
let _ = accepted_tx.send(());
}
*rp.extensions_mut() = extensions;
rp
}
Err(_) => MethodResponse::error_with_extensions(id, ErrorCode::InternalError, extensions),
}
Err(_) => MethodResponse::error(id, ErrorCode::InternalError),
};

rp.with_extensions(extensions)
})
})),
)?
Expand Down Expand Up @@ -905,13 +911,12 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let id = id.clone().into_owned();

Box::pin(async move {
match rx.await {
Ok(mut rp) => {
*rp.extensions_mut() = extensions;
rp
}
Err(_) => MethodResponse::error_with_extensions(id, ErrorCode::InternalError, extensions),
}
let rp = match rx.await {
Ok(rp) => rp,
Err(_) => MethodResponse::error(id, ErrorCode::InternalError),
};

rp.with_extensions(extensions)
})
})),
)?
Expand Down Expand Up @@ -953,12 +958,8 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
id
);

return MethodResponse::response_with_extensions(
id,
ResponsePayload::success(false),
max_response_size,
extensions,
);
return MethodResponse::response(id, ResponsePayload::success(false), max_response_size)
.with_extensions(extensions);
}
};

Expand Down
3 changes: 1 addition & 2 deletions examples/examples/rpc_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use std::sync::Arc;

use futures::future::BoxFuture;
use futures::FutureExt;
use jsonrpsee::core::{async_trait, client::ClientT};
use jsonrpsee::core::client::ClientT;
use jsonrpsee::rpc_params;
use jsonrpsee::server::middleware::rpc::{RpcServiceBuilder, RpcServiceT};
use jsonrpsee::server::{MethodResponse, RpcModule, Server};
Expand Down Expand Up @@ -85,7 +85,6 @@ pub struct GlobalCalls<S> {
count: Arc<AtomicUsize>,
}

#[async_trait]
impl<'a, S> RpcServiceT<'a> for GlobalCalls<S>
where
S: RpcServiceT<'a> + Send + Sync + Clone + 'static,
Expand Down
31 changes: 30 additions & 1 deletion examples/examples/server_with_connection_details.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use std::net::SocketAddr;
use jsonrpsee::core::async_trait;
use jsonrpsee::core::SubscriptionResult;
use jsonrpsee::proc_macros::rpc;
use jsonrpsee::server::middleware::rpc::RpcServiceT;
use jsonrpsee::server::{PendingSubscriptionSink, SubscriptionMessage};
use jsonrpsee::types::{ErrorObject, ErrorObjectOwned};
use jsonrpsee::ws_client::WsClientBuilder;
Expand All @@ -43,6 +44,22 @@ pub trait Rpc {

#[subscription(name = "subscribeConnectionId", item = usize, with_extensions)]
async fn sub(&self) -> SubscriptionResult;

#[subscription(name = "subscribeSyncConnectionId", item = usize, with_extensions)]
fn sub2(&self) -> SubscriptionResult;
}

struct LoggingMiddleware<S>(S);

impl<'a, S: RpcServiceT<'a>> RpcServiceT<'a> for LoggingMiddleware<S> {
type Future = S::Future;

fn call(&self, request: jsonrpsee::types::Request<'a>) -> Self::Future {
tracing::info!("Received request: {:?}", request);
assert!(request.extensions().get::<ConnectionId>().is_some());

self.0.call(request)
}
}

pub struct RpcServerImpl;
Expand All @@ -67,6 +84,16 @@ impl RpcServer for RpcServerImpl {
sink.send(SubscriptionMessage::from_json(&conn_id).unwrap()).await?;
Ok(())
}

fn sub2(&self, pending: PendingSubscriptionSink, ext: &Extensions) -> SubscriptionResult {
let conn_id = ext.get::<ConnectionId>().cloned().unwrap();

tokio::spawn(async move {
let sink = pending.accept().await.unwrap();
sink.send(SubscriptionMessage::from_json(&conn_id).unwrap()).await.unwrap();
});
Ok(())
}
}

#[tokio::main]
Expand Down Expand Up @@ -100,7 +127,9 @@ async fn main() -> anyhow::Result<()> {
}

async fn run_server() -> anyhow::Result<SocketAddr> {
let server = jsonrpsee::server::Server::builder().build("127.0.0.1:0").await?;
let rpc_middleware = jsonrpsee::server::middleware::rpc::RpcServiceBuilder::new().layer_fn(LoggingMiddleware);

let server = jsonrpsee::server::Server::builder().set_rpc_middleware(rpc_middleware).build("127.0.0.1:0").await?;
let addr = server.local_addr()?;

let handle = server.start(RpcServerImpl.into_rpc());
Expand Down
17 changes: 9 additions & 8 deletions server/src/middleware/rpc/layer/rpc_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,13 @@ impl<'a> RpcServiceT<'a> for RpcService {
let conn_id = self.conn_id;
let max_response_body_size = self.max_response_body_size;

let Request { id, method, params, mut extensions, .. } = req;
extensions.insert(conn_id);

let Request { id, method, params, extensions, .. } = req;
let params = jsonrpsee_types::Params::new(params.as_ref().map(|p| serde_json::value::RawValue::get(p)));

match self.methods.method_with_name(&method) {
None => {
let mut rp = MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound));
*rp.extensions_mut() = extensions;
let rp =
MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound)).with_extensions(extensions);
ResponseFuture::ready(rp)
}
Some((_name, method)) => match method {
Expand All @@ -115,7 +113,8 @@ impl<'a> RpcServiceT<'a> for RpcService {
} = self.cfg.clone()
else {
tracing::warn!("Subscriptions not supported");
let rp = MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError));
let rp = MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError))
.with_extensions(extensions);
return ResponseFuture::ready(rp);
};

Expand All @@ -127,7 +126,8 @@ impl<'a> RpcServiceT<'a> for RpcService {
ResponseFuture::future(fut)
} else {
let max = bounded_subscriptions.max();
let rp = MethodResponse::error(id, reject_too_many_subscriptions(max));
let rp =
MethodResponse::error(id, reject_too_many_subscriptions(max)).with_extensions(extensions);
ResponseFuture::ready(rp)
}
}
Expand All @@ -136,7 +136,8 @@ impl<'a> RpcServiceT<'a> for RpcService {

let RpcServiceCfg::CallsAndSubscriptions { .. } = self.cfg else {
tracing::warn!("Subscriptions not supported");
let rp = MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError));
let rp = MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError))
.with_extensions(extensions);
return ResponseFuture::ready(rp);
};

Expand Down
Loading

0 comments on commit 038a77f

Please sign in to comment.