From a38045ce263b69219bac7e2bd46a86cf8c666ea1 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 26 Apr 2022 16:47:28 +0200 Subject: [PATCH] feat: limit the number of subscriptions Closing #729 --- core/src/server/helpers.rs | 49 ++++++++++++++++++++++++++++++++++++++ ws-server/src/server.rs | 45 +++++++++++++++++++++------------- 2 files changed, 78 insertions(+), 16 deletions(-) diff --git a/core/src/server/helpers.rs b/core/src/server/helpers.rs index 2fb7bb319d..82fd33376e 100644 --- a/core/src/server/helpers.rs +++ b/core/src/server/helpers.rs @@ -25,6 +25,7 @@ // DEALINGS IN THE SOFTWARE. use std::io; +use std::sync::Arc; use crate::{to_json_raw_value, Error}; use futures_channel::mpsc; @@ -32,6 +33,7 @@ use futures_util::StreamExt; use jsonrpsee_types::error::{ErrorCode, ErrorObject, ErrorResponse, OVERSIZED_RESPONSE_CODE, OVERSIZED_RESPONSE_MSG}; use jsonrpsee_types::{Id, InvalidRequest, Response}; use serde::Serialize; +use tokio::sync::Notify; /// Bounded writer that allows writing at most `max_len` bytes. /// @@ -196,8 +198,41 @@ pub async fn collect_batch_response(rx: mpsc::UnboundedReceiver) -> Stri buf } +/// Wrapper over [`tokio::sync::Notify`] with bounds check. +#[derive(Debug)] +pub struct BoundedSubscriptions { + inner: Arc, + max_subscriptions: u32, +} + +impl BoundedSubscriptions { + /// Create a new bounded subscription. + pub fn new(max_subscriptions: u32) -> Self { + Self { inner: Arc::new(Notify::new()), max_subscriptions } + } + + /// The get a handle to a subscription + /// + /// Fails if `max_subscriptions` have been exceeded. + pub fn get(&self) -> Option> { + // The type itself increases the strong count by + if Arc::strong_count(&self.inner) as u32 > self.max_subscriptions { + None + } else { + Some(self.inner.clone()) + } + } + + /// Close all subscriptions. + pub fn close(&self) { + self.inner.notify_waiters(); + } +} + #[cfg(test)] mod tests { + use crate::server::helpers::BoundedSubscriptions; + use super::{BoundedWriter, Id, Response}; #[test] @@ -215,4 +250,18 @@ mod tests { // NOTE: `"` is part of the serialization so 101 characters. assert!(serde_json::to_writer(&mut writer, &"x".repeat(99)).is_err()); } + + #[test] + fn bounded_subscriptions_work() { + let subs = Arc::new(BoundedSubscriptions::new(5)); + let mut handles = Vec::new(); + + for _ in 0..5 { + handles.push(subs.get().unwrap()); + } + + assert!(subs.get().is_none()); + handles.swap_remove(0); + assert!(subs.get().is_some()); + } } diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 5e1b2872bd..136dc008fe 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -39,7 +39,7 @@ use futures_util::io::{BufReader, BufWriter}; use futures_util::stream::StreamExt; use jsonrpsee_core::id_providers::RandomIntegerIdProvider; use jsonrpsee_core::middleware::Middleware; -use jsonrpsee_core::server::helpers::{collect_batch_response, prepare_error, MethodSink}; +use jsonrpsee_core::server::helpers::{collect_batch_response, prepare_error, BoundedSubscriptions, MethodSink}; use jsonrpsee_core::server::resource_limiting::Resources; use jsonrpsee_core::server::rpc_module::{ConnState, ConnectionId, MethodKind, Methods}; use jsonrpsee_core::traits::IdProvider; @@ -49,7 +49,6 @@ use soketto::connection::Error as SokettoError; use soketto::handshake::{server::Response, Server as SokettoServer}; use soketto::Sender; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; -use tokio::sync::Notify; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; /// Default maximum connections allowed. @@ -271,6 +270,7 @@ where resources.clone(), cfg.max_request_body_size, cfg.max_response_body_size, + BoundedSubscriptions::new(cfg.max_subscriptions_per_connection), stop_monitor.clone(), middleware, id_provider, @@ -292,6 +292,7 @@ async fn background_task( resources: Resources, max_request_body_size: u32, max_response_body_size: u32, + bounded_subscriptions: BoundedSubscriptions, stop_server: StopMonitor, middleware: impl Middleware, id_provider: Arc, @@ -301,8 +302,8 @@ async fn background_task( builder.set_max_message_size(max_request_body_size as usize); let (mut sender, mut receiver) = builder.finish(); let (tx, mut rx) = mpsc::unbounded::(); - let close_notify = Arc::new(Notify::new()); - let close_notify_server_stop = close_notify.clone(); + let bounded_subscriptions = Arc::new(bounded_subscriptions); + let bounded_subscriptions2 = bounded_subscriptions.clone(); let stop_server2 = stop_server.clone(); let sink = MethodSink::new_with_limit(tx, max_response_body_size); @@ -327,7 +328,7 @@ async fn background_task( let _ = sender.close().await; // Notify all listeners and close down associated tasks. - close_notify_server_stop.notify_waiters(); + bounded_subscriptions2.close(); }); // Buffer for incoming data. @@ -436,11 +437,14 @@ async fn background_task( }, MethodKind::Subscription(callback) => match method.claim(&req.method, &resources) { Ok(guard) => { - let cn = close_notify.clone(); - let conn_state = - ConnState { conn_id, close_notify: cn, id_provider: &*id_provider }; - - let result = callback(id, params, &sink, conn_state); + let result = if let Some(cn) = bounded_subscriptions.get() { + let conn_state = + ConnState { conn_id, close_notify: cn, id_provider: &*id_provider }; + callback(id, params, &sink, conn_state) + } else { + sink.send_error(req.id, ErrorCode::ServerIsBusy.into()); + false + }; middleware.on_result(name, result, request_start); middleware.on_response(request_start); drop(guard); @@ -470,7 +474,7 @@ async fn background_task( let methods = &methods; let sink = sink.clone(); let id_provider = id_provider.clone(); - let close_notify2 = close_notify.clone(); + let bounded_subscriptions2 = bounded_subscriptions.clone(); let fut = async move { // Batch responses must be sent back as a single message so we read the results from each @@ -537,11 +541,17 @@ async fn background_task( MethodKind::Subscription(callback) => { match method_callback.claim(&req.method, resources) { Ok(guard) => { - let close_notify = close_notify2.clone(); - let conn_state = - ConnState { conn_id, close_notify, id_provider: &*id_provider }; - - let result = callback(id, params, &sink_batch, conn_state); + let result = if let Some(cn) = bounded_subscriptions2.get() { + let conn_state = ConnState { + conn_id, + close_notify: cn, + id_provider: &*id_provider, + }; + callback(id, params, &sink_batch, conn_state) + } else { + sink_batch.send_error(req.id, ErrorCode::ServerIsBusy.into()); + false + }; middleware.on_result(&req.method, result, request_start); drop(guard); None @@ -629,6 +639,8 @@ struct Settings { max_response_body_size: u32, /// Maximum number of incoming connections allowed. max_connections: u64, + /// Maximum number of subscriptions per connection. + max_subscriptions_per_connection: u32, /// Policy by which to accept or deny incoming requests based on the `Origin` header. allowed_origins: AllowedValue, /// Policy by which to accept or deny incoming requests based on the `Host` header. @@ -642,6 +654,7 @@ impl Default for Settings { Self { max_request_body_size: TEN_MB_SIZE_BYTES, max_response_body_size: TEN_MB_SIZE_BYTES, + max_subscriptions_per_connection: 1024, max_connections: MAX_CONNECTIONS, allowed_origins: AllowedValue::Any, allowed_hosts: AllowedValue::Any,