Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add a way to limit the number of subscriptions per connection #739

Merged
merged 13 commits into from
May 3, 2022
2 changes: 1 addition & 1 deletion benches/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jsonrpc-http-server = { version = "18.0.0", optional = true }
jsonrpc-pubsub = { version = "18.0.0", optional = true }
num_cpus = "1"
serde_json = "1"
tokio = { version = "1.8", features = ["rt-multi-thread"] }
tokio = { version = "1.16", features = ["rt-multi-thread"] }

[[bench]]
name = "bench"
Expand Down
4 changes: 2 additions & 2 deletions client/http-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ jsonrpsee-core = { path = "../../core", version = "0.11.0", features = ["client"
serde = { version = "1.0", default-features = false, features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
tokio = { version = "1.8", features = ["time"] }
tokio = { version = "1.16", features = ["time"] }
tracing = "0.1"

[dev-dependencies]
jsonrpsee-test-utils = { path = "../../test-utils" }
tokio = { version = "1.8", features = ["net", "rt-multi-thread", "macros"] }
tokio = { version = "1.16", features = ["net", "rt-multi-thread", "macros"] }

[features]
default = ["tls"]
Expand Down
4 changes: 2 additions & 2 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ rustc-hash = { version = "1", optional = true }
rand = { version = "0.8", optional = true }
soketto = { version = "0.7.1", optional = true }
parking_lot = { version = "0.12", optional = true }
tokio = { version = "1.8", optional = true }
tokio = { version = "1.16", optional = true }
wasm-bindgen-futures = { version = "0.4.19", optional = true }
futures-timer = { version = "3", optional = true }

Expand Down Expand Up @@ -66,5 +66,5 @@ async-wasm-client = [

[dev-dependencies]
serde_json = "1.0"
tokio = { version = "1.8", features = ["macros", "rt"] }
tokio = { version = "1.16", features = ["macros", "rt"] }
jsonrpsee = { path = "../jsonrpsee", features = ["server", "macros"] }
5 changes: 5 additions & 0 deletions core/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ impl<Notif> Subscription<Notif> {
) -> Self {
Self { to_back, notifs_rx, kind, marker: PhantomData }
}

/// Return the subscription type and, if applicable, ID.
pub fn kind(&self) -> &SubscriptionKind {
&self.kind
}
}

/// Batch request message.
Expand Down
61 changes: 61 additions & 0 deletions core/src/server/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
// DEALINGS IN THE SOFTWARE.

use std::io;
use std::sync::Arc;

use crate::{to_json_raw_value, Error};
use futures_channel::mpsc;
use futures_util::StreamExt;
use jsonrpsee_types::error::{ErrorCode, ErrorObject, ErrorResponse, OVERSIZED_RESPONSE_CODE, OVERSIZED_RESPONSE_MSG};
use jsonrpsee_types::{Id, InvalidRequest, Response};
use serde::Serialize;
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore};

/// Bounded writer that allows writing at most `max_len` bytes.
///
Expand Down Expand Up @@ -196,8 +198,53 @@ pub async fn collect_batch_response(rx: mpsc::UnboundedReceiver<String>) -> Stri
buf
}

/// A permitted subscription.
#[derive(Debug)]
pub struct SubscriptionPermit {
_permit: OwnedSemaphorePermit,
resource: Arc<Notify>,
}

impl SubscriptionPermit {
/// Get the handle to [`tokio::sync::Notify`].
pub fn handle(&self) -> Arc<Notify> {
self.resource.clone()
}
}

/// Wrapper over [`tokio::sync::Notify`] with bounds check.
#[derive(Debug, Clone)]
pub struct BoundedSubscriptions {
jsdw marked this conversation as resolved.
Show resolved Hide resolved
resource: Arc<Notify>,
guard: Arc<Semaphore>,
}

impl BoundedSubscriptions {
/// Create a new bounded subscription.
pub fn new(max_subscriptions: u32) -> Self {
Self { resource: Arc::new(Notify::new()), guard: Arc::new(Semaphore::new(max_subscriptions as usize)) }
}

/// Attempts to acquire a subscription slot.
///
/// Fails if `max_subscriptions` have been exceeded.
pub fn acquire(&self) -> Option<SubscriptionPermit> {
Arc::clone(&self.guard)
.try_acquire_owned()
.ok()
.map(|p| SubscriptionPermit { _permit: p, resource: self.resource.clone() })
}

/// Close all subscriptions.
pub fn close(&self) {
self.resource.notify_waiters();
}
}

#[cfg(test)]
mod tests {
use crate::server::helpers::BoundedSubscriptions;

use super::{BoundedWriter, Id, Response};

#[test]
Expand All @@ -215,4 +262,18 @@ mod tests {
// NOTE: `"` is part of the serialization so 101 characters.
assert!(serde_json::to_writer(&mut writer, &"x".repeat(99)).is_err());
}

#[test]
fn bounded_subscriptions_work() {
let subs = BoundedSubscriptions::new(5);
let mut handles = Vec::new();

for _ in 0..5 {
handles.push(subs.acquire().unwrap());
}

assert!(subs.acquire().is_none());
handles.swap_remove(0);
assert!(subs.acquire().is_some());
}
}
49 changes: 31 additions & 18 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use std::sync::Arc;

use crate::error::{Error, SubscriptionClosed};
use crate::id_providers::RandomIntegerIdProvider;
use crate::server::helpers::MethodSink;
use crate::server::helpers::{BoundedSubscriptions, MethodSink, SubscriptionPermit};
use crate::server::resource_limiting::{ResourceGuard, ResourceTable, ResourceVec, Resources};
use crate::traits::{IdProvider, ToRpcParams};
use futures_channel::mpsc;
Expand All @@ -48,7 +48,7 @@ use jsonrpsee_types::{
use parking_lot::Mutex;
use rustc_hash::FxHashMap;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::{watch, Notify};
use tokio::sync::watch;

/// A `MethodCallback` is an RPC endpoint, callable with a standard JSON-RPC request,
/// implemented as a function pointer to a `Fn` function taking four arguments:
Expand All @@ -61,6 +61,8 @@ pub type AsyncMethod<'a> = Arc<
>;
/// Method callback for subscriptions.
pub type SubscriptionMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink, ConnState) -> bool>;
// Method callback to unsubscribe.
type UnsubscriptionMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink, ConnectionId) -> bool>;

/// Connection ID, used for stateful protocol such as WebSockets.
/// For stateless protocols such as http it's unused, so feel free to set it some hardcoded value.
Expand All @@ -70,15 +72,15 @@ pub type ConnectionId = usize;
/// A 3-tuple containing:
/// - Call result as a `String`,
/// - a [`mpsc::UnboundedReceiver<String>`] to receive future subscription results
/// - a [`tokio::sync::Notify`] to allow subscribers to notify their [`SubscriptionSink`] when they disconnect.
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, Arc<Notify>);
/// - a [`crate::server::helpers::SubscriptionPermit`] to allow subscribers to notify their [`SubscriptionSink`] when they disconnect.
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, SubscriptionPermit);

/// Helper struct to manage subscriptions.
pub struct ConnState<'a> {
/// Connection ID
pub conn_id: ConnectionId,
/// Get notified when the connection to subscribers is closed.
pub close_notify: Arc<Notify>,
/// Get notified when the connection to subscribers is closed.1
jsdw marked this conversation as resolved.
Show resolved Hide resolved
pub close_notify: SubscriptionPermit,
/// ID provider.
pub id_provider: &'a dyn IdProvider,
}
Expand Down Expand Up @@ -114,8 +116,10 @@ pub enum MethodKind {
Sync(SyncMethod),
/// Asynchronous method handler.
Async(AsyncMethod<'static>),
/// Subscription method handler
/// Subscription method handler.
Subscription(SubscriptionMethod),
/// Unsubscription method handler.
Unsubscription(UnsubscriptionMethod),
}

/// Information about resources the method uses during its execution. Initialized when the the server starts.
Expand Down Expand Up @@ -189,6 +193,13 @@ impl MethodCallback {
}
}

fn new_unsubscription(callback: UnsubscriptionMethod) -> Self {
MethodCallback {
callback: MethodKind::Unsubscription(callback),
resources: MethodResources::Uninitialized([].into()),
}
}

/// Attempt to claim resources prior to executing a method. On success returns a guard that releases
/// claimed resources when dropped.
pub fn claim(&self, name: &str, resources: &Resources) -> Result<ResourceGuard, Error> {
Expand All @@ -210,6 +221,7 @@ impl Debug for MethodKind {
Self::Async(_) => write!(f, "Async"),
Self::Sync(_) => write!(f, "Sync"),
Self::Subscription(_) => write!(f, "Subscription"),
Self::Unsubscription(_) => write!(f, "Unsubscription"),
}
}
}
Expand Down Expand Up @@ -393,17 +405,19 @@ impl Methods {
let sink = MethodSink::new(tx_sink);
let id = req.id.clone();
let params = Params::new(req.params.map(|params| params.get()));
let notify = Arc::new(Notify::new());
let bounded_subs = BoundedSubscriptions::new(u32::MAX);
let close_notify = bounded_subs.acquire().unwrap();
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
let notify = bounded_subs.acquire().unwrap();
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved

let _result = match self.method(&req.method).map(|c| &c.callback) {
None => sink.send_error(req.id, ErrorCode::MethodNotFound.into()),
Some(MethodKind::Sync(cb)) => (cb)(id, params, &sink),
Some(MethodKind::Async(cb)) => (cb)(id.into_owned(), params.into_owned(), sink, 0, None).await,
Some(MethodKind::Subscription(cb)) => {
let close_notify = notify.clone();
let conn_state = ConnState { conn_id: 0, close_notify, id_provider: &RandomIntegerIdProvider };
(cb)(id, params, &sink, conn_state)
}
Some(MethodKind::Unsubscription(cb)) => (cb)(id, params, &sink, 0),
};

let resp = rx_sink.next().await.expect("tx and rx still alive; qed");
Expand Down Expand Up @@ -707,7 +721,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
{
self.methods.mut_callbacks().insert(
unsubscribe_method_name,
MethodCallback::new_subscription(Arc::new(move |id, params, sink, conn| {
MethodCallback::new_unsubscription(Arc::new(move |id, params, sink, conn_id| {
let sub_id = match params.one::<RpcSubscriptionId>() {
Ok(sub_id) => sub_id,
Err(_) => {
Expand All @@ -722,8 +736,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
};
let sub_id = sub_id.into_owned();

let result =
subscribers.lock().remove(&SubscriptionKey { conn_id: conn.conn_id, sub_id }).is_some();
let result = subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id }).is_some();

sink.send_response(id, result)
})),
Expand Down Expand Up @@ -757,7 +770,7 @@ struct InnerPendingSubscription {
/// Sink.
sink: MethodSink,
/// Get notified when subscribers leave so we can exit
close_notify: Option<Arc<Notify>>,
close_notify: Option<SubscriptionPermit>,
/// MethodCallback.
method: &'static str,
/// Unique subscription.
Expand Down Expand Up @@ -819,7 +832,7 @@ pub struct SubscriptionSink {
/// Sink.
inner: MethodSink,
/// Get notified when subscribers leave so we can exit
close_notify: Option<Arc<Notify>>,
close_notify: Option<SubscriptionPermit>,
/// MethodCallback.
method: &'static str,
/// Unique subscription.
Expand Down Expand Up @@ -892,8 +905,8 @@ impl SubscriptionSink {
T: Serialize,
E: std::fmt::Display,
{
let conn_closed = match self.close_notify.clone() {
Some(close_notify) => close_notify,
let conn_closed = match self.close_notify.as_ref().map(|cn| cn.handle()) {
Some(cn) => cn,
None => {
return SubscriptionClosed::RemotePeerAborted;
}
Expand Down Expand Up @@ -1035,7 +1048,7 @@ impl Drop for SubscriptionSink {
/// Wrapper struct that maintains a subscription "mainly" for testing.
#[derive(Debug)]
pub struct Subscription {
close_notify: Option<Arc<Notify>>,
close_notify: Option<SubscriptionPermit>,
rx: mpsc::UnboundedReceiver<String>,
sub_id: RpcSubscriptionId<'static>,
}
Expand All @@ -1045,7 +1058,7 @@ impl Subscription {
pub fn close(&mut self) {
tracing::trace!("[Subscription::close] Notifying");
if let Some(n) = self.close_notify.take() {
n.notify_one()
n.handle().notify_one()
}
}
/// Get the subscription ID
Expand Down
2 changes: 1 addition & 1 deletion examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ futures = "0.3"
jsonrpsee = { path = "../jsonrpsee", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
tokio = { version = "1.8", features = ["full"] }
tokio = { version = "1.16", features = ["full"] }
tokio-stream = { version = "0.1", features = ["sync"] }
serde_json = { version = "1" }

Expand Down
2 changes: 1 addition & 1 deletion http-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ globset = "0.4"
lazy_static = "1.4"
tracing = "0.1"
serde_json = "1"
tokio = { version = "1.8", features = ["rt-multi-thread", "macros"] }
tokio = { version = "1.16", features = ["rt-multi-thread", "macros"] }
unicase = "2.6.0"

[dev-dependencies]
Expand Down
4 changes: 2 additions & 2 deletions http-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ async fn process_validated_request(
false
}
},
MethodKind::Subscription(_) => {
MethodKind::Subscription(_) | MethodKind::Unsubscription(_) => {
tracing::error!("Subscriptions not supported on HTTP");
sink.send_error(req.id, ErrorCode::InternalError.into());
false
Expand Down Expand Up @@ -622,7 +622,7 @@ async fn process_validated_request(
None
}
},
MethodKind::Subscription(_) => {
MethodKind::Subscription(_) | MethodKind::Unsubscription(_) => {
tracing::error!("Subscriptions not supported on HTTP");
sink.send_error(req.id, ErrorCode::InternalError.into());
middleware.on_result(&req.method, false, request_start);
Expand Down
2 changes: 1 addition & 1 deletion proc-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ proc-macro-crate = "1"
[dev-dependencies]
jsonrpsee = { path = "../jsonrpsee", features = ["full"] }
trybuild = "1.0"
tokio = { version = "1.8", features = ["rt", "macros"] }
tokio = { version = "1.16", features = ["rt", "macros"] }
futures-channel = { version = "0.3.14", default-features = false }
futures-util = { version = "0.3.14", default-features = false }
2 changes: 1 addition & 1 deletion test-utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ tracing = "0.1"
serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = "1"
soketto = { version = "0.7.1", features = ["http"] }
tokio = { version = "1.8", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio = { version = "1.16", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.7", features = ["compat"] }
2 changes: 1 addition & 1 deletion tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ env_logger = "0.9"
beef = { version = "0.5.1", features = ["impl_serde"] }
futures = { version = "0.3.14", default-features = false, features = ["std"] }
jsonrpsee = { path = "../jsonrpsee", features = ["full"] }
tokio = { version = "1.8", features = ["full"] }
tokio = { version = "1.16", features = ["full"] }
tracing = "0.1"
serde = "1"
serde_json = "1"
Expand Down
Loading