diff --git a/types/src/error.rs b/types/src/error.rs index 607e193fde..f502c27b11 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -66,6 +66,9 @@ pub enum Error { /// Method was already registered. #[error("Method: {0} was already registered")] MethodAlreadyRegistered(String), + /// Method with that name has not yet been registered. + #[error("Method: {0} has not yet been registered")] + MethodNotFound(String), /// Subscribe and unsubscribe method names are the same. #[error("Cannot use the same method name for subscribe and unsubscribe, used: {0}")] SubscriptionNameConflict(String), diff --git a/utils/src/server/rpc_module.rs b/utils/src/server/rpc_module.rs index 602e724fd8..bd4f1c7e44 100644 --- a/utils/src/server/rpc_module.rs +++ b/utils/src/server/rpc_module.rs @@ -18,7 +18,7 @@ use std::sync::Arc; /// implemented as a function pointer to a `Fn` function taking four arguments: /// the `id`, `params`, a channel the function uses to communicate the result (or error) /// back to `jsonrpsee`, and the connection ID (useful for the websocket transport). -pub type SyncMethod = Box Result<(), Error>>; +pub type SyncMethod = Arc Result<(), Error>>; /// Similar to [`SyncMethod`], but represents an asynchronous handler. pub type AsyncMethod = Arc< dyn Send + Sync + Fn(OwnedId, OwnedRpcParams, MethodSink, ConnectionId) -> BoxFuture<'static, Result<(), Error>>, @@ -41,6 +41,7 @@ struct SubscriptionKey { } /// Callback wrapper that can be either sync or async. +#[derive(Clone)] pub enum MethodCallback { /// Synchronous method handler. Sync(SyncMethod), @@ -81,10 +82,10 @@ impl Debug for MethodCallback { } } -/// Collection of synchronous and asynchronous methods. -#[derive(Default, Debug)] +/// Reference-counted, clone-on-write collection of synchronous and asynchronous methods. +#[derive(Default, Debug, Clone)] pub struct Methods { - callbacks: FxHashMap<&'static str, MethodCallback>, + callbacks: Arc>, } impl Methods { @@ -101,15 +102,22 @@ impl Methods { Ok(()) } + /// Helper for obtaining a mut ref to the callbacks HashMap. + fn mut_callbacks(&mut self) -> &mut FxHashMap<&'static str, MethodCallback> { + Arc::make_mut(&mut self.callbacks) + } + /// Merge two [`Methods`]'s by adding all [`MethodCallback`]s from `other` into `self`. /// Fails if any of the methods in `other` is present already. - pub fn merge(&mut self, other: Methods) -> Result<(), Error> { + pub fn merge(&mut self, mut other: Methods) -> Result<(), Error> { for name in other.callbacks.keys() { self.verify_method_name(name)?; } - for (name, callback) in other.callbacks { - self.callbacks.insert(name, callback); + let callbacks = self.mut_callbacks(); + + for (name, callback) in other.mut_callbacks().drain() { + callbacks.insert(name, callback); } Ok(()) @@ -137,17 +145,33 @@ impl Methods { /// Sets of JSON-RPC methods can be organized into a "module"s that are in turn registered on the server or, /// alternatively, merged with other modules to construct a cohesive API. [`RpcModule`] wraps an additional context /// argument that can be used to access data during call execution. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct RpcModule { ctx: Arc, methods: Methods, } -impl RpcModule { +impl RpcModule { /// Create a new module with a given shared `Context`. pub fn new(ctx: Context) -> Self { Self { ctx: Arc::new(ctx), methods: Default::default() } } + + /// Convert a module into methods. Consumes self. + pub fn into_methods(self) -> Methods { + self.methods + } + + /// Merge two [`RpcModule`]'s by adding all [`Methods`] `other` into `self`. + /// Fails if any of the methods in `other` is present already. + pub fn merge(&mut self, other: RpcModule) -> Result<(), Error> { + self.methods.merge(other.methods)?; + + Ok(()) + } +} + +impl RpcModule { /// Register a new synchronous RPC method, which computes the response with the given callback. pub fn register_method(&mut self, method_name: &'static str, callback: F) -> Result<(), Error> where @@ -159,9 +183,9 @@ impl RpcModule { let ctx = self.ctx.clone(); - self.methods.callbacks.insert( + self.methods.mut_callbacks().insert( method_name, - MethodCallback::Sync(Box::new(move |id, params, tx, _| { + MethodCallback::Sync(Arc::new(move |id, params, tx, _| { match callback(params, &*ctx) { Ok(res) => send_response(id, tx, res), Err(CallError::InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), @@ -192,7 +216,7 @@ impl RpcModule { let ctx = self.ctx.clone(); - self.methods.callbacks.insert( + self.methods.mut_callbacks().insert( method_name, MethodCallback::Async(Arc::new(move |id, params, tx, _| { let ctx = ctx.clone(); @@ -265,9 +289,9 @@ impl RpcModule { { let subscribers = subscribers.clone(); - self.methods.callbacks.insert( + self.methods.mut_callbacks().insert( subscribe_method_name, - MethodCallback::Sync(Box::new(move |id, params, method_sink, conn_id| { + MethodCallback::Sync(Arc::new(move |id, params, method_sink, conn_id| { let (conn_tx, conn_rx) = oneshot::channel::<()>(); let sub_id = { const JS_NUM_MASK: SubscriptionId = !0 >> 11; @@ -293,9 +317,9 @@ impl RpcModule { } { - self.methods.callbacks.insert( + self.methods.mut_callbacks().insert( unsubscribe_method_name, - MethodCallback::Sync(Box::new(move |id, params, tx, conn_id| { + MethodCallback::Sync(Arc::new(move |id, params, tx, conn_id| { let sub_id = params.one()?; subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id }); send_response(id, &tx, "Unsubscribed"); @@ -308,15 +332,16 @@ impl RpcModule { Ok(()) } - /// Convert a module into methods. Consumes self. - pub fn into_methods(self) -> Methods { - self.methods - } + /// Register an `alias` name for an `existing_method`. + pub fn register_alias(&mut self, alias: &'static str, existing_method: &'static str) -> Result<(), Error> { + self.methods.verify_method_name(alias)?; - /// Merge two [`RpcModule`]'s by adding all [`Methods`] `other` into `self`. - /// Fails if any of the methods in `other` is present already. - pub fn merge(&mut self, other: RpcModule) -> Result<(), Error> { - self.methods.merge(other.methods)?; + let callback = match self.methods.callbacks.get(existing_method) { + Some(callback) => callback.clone(), + None => return Err(Error::MethodNotFound(existing_method.into())), + }; + + self.methods.mut_callbacks().insert(alias, callback); Ok(()) } @@ -431,4 +456,17 @@ mod tests { assert!(methods.method("hi").is_some()); assert!(methods.method("goodbye").is_some()); } + + #[test] + fn rpc_register_alias() { + let mut module = RpcModule::new(()); + + module.register_method("hello_world", |_: RpcParams, _| Ok(())).unwrap(); + module.register_alias("hello_foobar", "hello_world").unwrap(); + + let methods = module.into_methods(); + + assert!(methods.method("hello_world").is_some()); + assert!(methods.method("hello_foobar").is_some()); + } } diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index e5f0b7efe4..62f82daac4 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -56,7 +56,7 @@ impl Server { /// Register all methods from a [`Methods`] of provided [`RpcModule`] on this server. /// In case a method already is registered with the same name, no method is added and a [`Error::MethodAlreadyRegistered`] /// is returned. Note that the [`RpcModule`] is consumed after this call. - pub fn register_module(&mut self, module: RpcModule) -> Result<(), Error> { + pub fn register_module(&mut self, module: RpcModule) -> Result<(), Error> { self.methods.merge(module.into_methods())?; Ok(()) } @@ -74,7 +74,8 @@ impl Server { /// Start responding to connections requests. This will block current thread until the server is stopped. pub async fn start(self) { let mut incoming = TcpListenerStream::new(self.listener); - let methods = Arc::new(self.methods); + let methods = self.methods; + let conn_counter = Arc::new(()); let cfg = self.cfg; let mut id = 0; @@ -82,13 +83,18 @@ impl Server { if let Ok(socket) = socket { socket.set_nodelay(true).unwrap_or_else(|e| panic!("Could not set NODELAY on socket: {:?}", e)); - if Arc::strong_count(&methods) > self.cfg.max_connections as usize { + if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize { log::warn!("Too many connections. Try again in a while"); continue; } let methods = methods.clone(); + let counter = conn_counter.clone(); - tokio::spawn(background_task(socket, id, methods, cfg)); + tokio::spawn(async move { + let r = background_task(socket, id, methods, cfg).await; + drop(counter); + r + }); id += 1; } @@ -99,7 +105,7 @@ impl Server { async fn background_task( socket: tokio::net::TcpStream, conn_id: ConnectionId, - methods: Arc, + methods: Methods, cfg: Settings, ) -> Result<(), Error> { // For each incoming background_task we perform a handshake.