Skip to content

Commit

Permalink
Method aliases + RpcModule: Clone (#383)
Browse files Browse the repository at this point in the history
* Make sync methods into Arc pointers

* impl Clone for RpcModule and Methods

* No need to wrap Methods in Arc anymore

* Simplify generics

* register_alias

* fmt

* grammar

Co-authored-by: James Wilson <james@jsdw.me>

* Use a separate Arc counter for tracking max_connections

Co-authored-by: James Wilson <james@jsdw.me>
  • Loading branch information
maciejhirsz and jsdw authored Jun 18, 2021
1 parent 82b1614 commit 6c69a8c
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 29 deletions.
3 changes: 3 additions & 0 deletions types/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ pub enum Error {
/// Method was already registered.
#[error("Method: {0} was already registered")]
MethodAlreadyRegistered(String),
/// Method with that name has not yet been registered.
#[error("Method: {0} has not yet been registered")]
MethodNotFound(String),
/// Subscribe and unsubscribe method names are the same.
#[error("Cannot use the same method name for subscribe and unsubscribe, used: {0}")]
SubscriptionNameConflict(String),
Expand Down
86 changes: 62 additions & 24 deletions utils/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::sync::Arc;
/// implemented as a function pointer to a `Fn` function taking four arguments:
/// the `id`, `params`, a channel the function uses to communicate the result (or error)
/// back to `jsonrpsee`, and the connection ID (useful for the websocket transport).
pub type SyncMethod = Box<dyn Send + Sync + Fn(Id, RpcParams, &MethodSink, ConnectionId) -> Result<(), Error>>;
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, RpcParams, &MethodSink, ConnectionId) -> Result<(), Error>>;
/// Similar to [`SyncMethod`], but represents an asynchronous handler.
pub type AsyncMethod = Arc<
dyn Send + Sync + Fn(OwnedId, OwnedRpcParams, MethodSink, ConnectionId) -> BoxFuture<'static, Result<(), Error>>,
Expand All @@ -41,6 +41,7 @@ struct SubscriptionKey {
}

/// Callback wrapper that can be either sync or async.
#[derive(Clone)]
pub enum MethodCallback {
/// Synchronous method handler.
Sync(SyncMethod),
Expand Down Expand Up @@ -81,10 +82,10 @@ impl Debug for MethodCallback {
}
}

/// Collection of synchronous and asynchronous methods.
#[derive(Default, Debug)]
/// Reference-counted, clone-on-write collection of synchronous and asynchronous methods.
#[derive(Default, Debug, Clone)]
pub struct Methods {
callbacks: FxHashMap<&'static str, MethodCallback>,
callbacks: Arc<FxHashMap<&'static str, MethodCallback>>,
}

impl Methods {
Expand All @@ -101,15 +102,22 @@ impl Methods {
Ok(())
}

/// Helper for obtaining a mut ref to the callbacks HashMap.
fn mut_callbacks(&mut self) -> &mut FxHashMap<&'static str, MethodCallback> {
Arc::make_mut(&mut self.callbacks)
}

/// Merge two [`Methods`]'s by adding all [`MethodCallback`]s from `other` into `self`.
/// Fails if any of the methods in `other` is present already.
pub fn merge(&mut self, other: Methods) -> Result<(), Error> {
pub fn merge(&mut self, mut other: Methods) -> Result<(), Error> {
for name in other.callbacks.keys() {
self.verify_method_name(name)?;
}

for (name, callback) in other.callbacks {
self.callbacks.insert(name, callback);
let callbacks = self.mut_callbacks();

for (name, callback) in other.mut_callbacks().drain() {
callbacks.insert(name, callback);
}

Ok(())
Expand Down Expand Up @@ -137,17 +145,33 @@ impl Methods {
/// Sets of JSON-RPC methods can be organized into a "module"s that are in turn registered on the server or,
/// alternatively, merged with other modules to construct a cohesive API. [`RpcModule`] wraps an additional context
/// argument that can be used to access data during call execution.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct RpcModule<Context> {
ctx: Arc<Context>,
methods: Methods,
}

impl<Context: Send + Sync + 'static> RpcModule<Context> {
impl<Context> RpcModule<Context> {
/// Create a new module with a given shared `Context`.
pub fn new(ctx: Context) -> Self {
Self { ctx: Arc::new(ctx), methods: Default::default() }
}

/// Convert a module into methods. Consumes self.
pub fn into_methods(self) -> Methods {
self.methods
}

/// Merge two [`RpcModule`]'s by adding all [`Methods`] `other` into `self`.
/// Fails if any of the methods in `other` is present already.
pub fn merge<Context2>(&mut self, other: RpcModule<Context2>) -> Result<(), Error> {
self.methods.merge(other.methods)?;

Ok(())
}
}

impl<Context: Send + Sync + 'static> RpcModule<Context> {
/// Register a new synchronous RPC method, which computes the response with the given callback.
pub fn register_method<R, F>(&mut self, method_name: &'static str, callback: F) -> Result<(), Error>
where
Expand All @@ -159,9 +183,9 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {

let ctx = self.ctx.clone();

self.methods.callbacks.insert(
self.methods.mut_callbacks().insert(
method_name,
MethodCallback::Sync(Box::new(move |id, params, tx, _| {
MethodCallback::Sync(Arc::new(move |id, params, tx, _| {
match callback(params, &*ctx) {
Ok(res) => send_response(id, tx, res),
Err(CallError::InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()),
Expand Down Expand Up @@ -192,7 +216,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {

let ctx = self.ctx.clone();

self.methods.callbacks.insert(
self.methods.mut_callbacks().insert(
method_name,
MethodCallback::Async(Arc::new(move |id, params, tx, _| {
let ctx = ctx.clone();
Expand Down Expand Up @@ -265,9 +289,9 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {

{
let subscribers = subscribers.clone();
self.methods.callbacks.insert(
self.methods.mut_callbacks().insert(
subscribe_method_name,
MethodCallback::Sync(Box::new(move |id, params, method_sink, conn_id| {
MethodCallback::Sync(Arc::new(move |id, params, method_sink, conn_id| {
let (conn_tx, conn_rx) = oneshot::channel::<()>();
let sub_id = {
const JS_NUM_MASK: SubscriptionId = !0 >> 11;
Expand All @@ -293,9 +317,9 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
}

{
self.methods.callbacks.insert(
self.methods.mut_callbacks().insert(
unsubscribe_method_name,
MethodCallback::Sync(Box::new(move |id, params, tx, conn_id| {
MethodCallback::Sync(Arc::new(move |id, params, tx, conn_id| {
let sub_id = params.one()?;
subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id });
send_response(id, &tx, "Unsubscribed");
Expand All @@ -308,15 +332,16 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
Ok(())
}

/// Convert a module into methods. Consumes self.
pub fn into_methods(self) -> Methods {
self.methods
}
/// Register an `alias` name for an `existing_method`.
pub fn register_alias(&mut self, alias: &'static str, existing_method: &'static str) -> Result<(), Error> {
self.methods.verify_method_name(alias)?;

/// Merge two [`RpcModule`]'s by adding all [`Methods`] `other` into `self`.
/// Fails if any of the methods in `other` is present already.
pub fn merge<Context2>(&mut self, other: RpcModule<Context2>) -> Result<(), Error> {
self.methods.merge(other.methods)?;
let callback = match self.methods.callbacks.get(existing_method) {
Some(callback) => callback.clone(),
None => return Err(Error::MethodNotFound(existing_method.into())),
};

self.methods.mut_callbacks().insert(alias, callback);

Ok(())
}
Expand Down Expand Up @@ -431,4 +456,17 @@ mod tests {
assert!(methods.method("hi").is_some());
assert!(methods.method("goodbye").is_some());
}

#[test]
fn rpc_register_alias() {
let mut module = RpcModule::new(());

module.register_method("hello_world", |_: RpcParams, _| Ok(())).unwrap();
module.register_alias("hello_foobar", "hello_world").unwrap();

let methods = module.into_methods();

assert!(methods.method("hello_world").is_some());
assert!(methods.method("hello_foobar").is_some());
}
}
16 changes: 11 additions & 5 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl Server {
/// Register all methods from a [`Methods`] of provided [`RpcModule`] on this server.
/// In case a method already is registered with the same name, no method is added and a [`Error::MethodAlreadyRegistered`]
/// is returned. Note that the [`RpcModule`] is consumed after this call.
pub fn register_module<Context: Send + Sync + 'static>(&mut self, module: RpcModule<Context>) -> Result<(), Error> {
pub fn register_module<Context>(&mut self, module: RpcModule<Context>) -> Result<(), Error> {
self.methods.merge(module.into_methods())?;
Ok(())
}
Expand All @@ -74,21 +74,27 @@ impl Server {
/// Start responding to connections requests. This will block current thread until the server is stopped.
pub async fn start(self) {
let mut incoming = TcpListenerStream::new(self.listener);
let methods = Arc::new(self.methods);
let methods = self.methods;
let conn_counter = Arc::new(());
let cfg = self.cfg;
let mut id = 0;

while let Some(socket) = incoming.next().await {
if let Ok(socket) = socket {
socket.set_nodelay(true).unwrap_or_else(|e| panic!("Could not set NODELAY on socket: {:?}", e));

if Arc::strong_count(&methods) > self.cfg.max_connections as usize {
if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize {
log::warn!("Too many connections. Try again in a while");
continue;
}
let methods = methods.clone();
let counter = conn_counter.clone();

tokio::spawn(background_task(socket, id, methods, cfg));
tokio::spawn(async move {
let r = background_task(socket, id, methods, cfg).await;
drop(counter);
r
});

id += 1;
}
Expand All @@ -99,7 +105,7 @@ impl Server {
async fn background_task(
socket: tokio::net::TcpStream,
conn_id: ConnectionId,
methods: Arc<Methods>,
methods: Methods,
cfg: Settings,
) -> Result<(), Error> {
// For each incoming background_task we perform a handshake.
Expand Down

0 comments on commit 6c69a8c

Please sign in to comment.