Skip to content

Commit

Permalink
feat: limit the number of subscriptions
Browse files Browse the repository at this point in the history
Closing #729
  • Loading branch information
niklasad1 committed Apr 26, 2022
1 parent e6c6ac5 commit a38045c
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 16 deletions.
49 changes: 49 additions & 0 deletions core/src/server/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
// DEALINGS IN THE SOFTWARE.

use std::io;
use std::sync::Arc;

use crate::{to_json_raw_value, Error};
use futures_channel::mpsc;
use futures_util::StreamExt;
use jsonrpsee_types::error::{ErrorCode, ErrorObject, ErrorResponse, OVERSIZED_RESPONSE_CODE, OVERSIZED_RESPONSE_MSG};
use jsonrpsee_types::{Id, InvalidRequest, Response};
use serde::Serialize;
use tokio::sync::Notify;

/// Bounded writer that allows writing at most `max_len` bytes.
///
Expand Down Expand Up @@ -196,8 +198,41 @@ pub async fn collect_batch_response(rx: mpsc::UnboundedReceiver<String>) -> Stri
buf
}

/// Wrapper over [`tokio::sync::Notify`] with bounds check.
#[derive(Debug)]
pub struct BoundedSubscriptions {
inner: Arc<Notify>,
max_subscriptions: u32,
}

impl BoundedSubscriptions {
/// Create a new bounded subscription.
pub fn new(max_subscriptions: u32) -> Self {
Self { inner: Arc::new(Notify::new()), max_subscriptions }
}

/// The get a handle to a subscription
///
/// Fails if `max_subscriptions` have been exceeded.
pub fn get(&self) -> Option<Arc<Notify>> {
// The type itself increases the strong count by
if Arc::strong_count(&self.inner) as u32 > self.max_subscriptions {
None
} else {
Some(self.inner.clone())
}
}

/// Close all subscriptions.
pub fn close(&self) {
self.inner.notify_waiters();
}
}

#[cfg(test)]
mod tests {
use crate::server::helpers::BoundedSubscriptions;

use super::{BoundedWriter, Id, Response};

#[test]
Expand All @@ -215,4 +250,18 @@ mod tests {
// NOTE: `"` is part of the serialization so 101 characters.
assert!(serde_json::to_writer(&mut writer, &"x".repeat(99)).is_err());
}

#[test]
fn bounded_subscriptions_work() {
let subs = Arc::new(BoundedSubscriptions::new(5));
let mut handles = Vec::new();

for _ in 0..5 {
handles.push(subs.get().unwrap());
}

assert!(subs.get().is_none());
handles.swap_remove(0);
assert!(subs.get().is_some());
}
}
45 changes: 29 additions & 16 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use futures_util::io::{BufReader, BufWriter};
use futures_util::stream::StreamExt;
use jsonrpsee_core::id_providers::RandomIntegerIdProvider;
use jsonrpsee_core::middleware::Middleware;
use jsonrpsee_core::server::helpers::{collect_batch_response, prepare_error, MethodSink};
use jsonrpsee_core::server::helpers::{collect_batch_response, prepare_error, BoundedSubscriptions, MethodSink};
use jsonrpsee_core::server::resource_limiting::Resources;
use jsonrpsee_core::server::rpc_module::{ConnState, ConnectionId, MethodKind, Methods};
use jsonrpsee_core::traits::IdProvider;
Expand All @@ -49,7 +49,6 @@ use soketto::connection::Error as SokettoError;
use soketto::handshake::{server::Response, Server as SokettoServer};
use soketto::Sender;
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio::sync::Notify;
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};

/// Default maximum connections allowed.
Expand Down Expand Up @@ -271,6 +270,7 @@ where
resources.clone(),
cfg.max_request_body_size,
cfg.max_response_body_size,
BoundedSubscriptions::new(cfg.max_subscriptions_per_connection),
stop_monitor.clone(),
middleware,
id_provider,
Expand All @@ -292,6 +292,7 @@ async fn background_task(
resources: Resources,
max_request_body_size: u32,
max_response_body_size: u32,
bounded_subscriptions: BoundedSubscriptions,
stop_server: StopMonitor,
middleware: impl Middleware,
id_provider: Arc<dyn IdProvider>,
Expand All @@ -301,8 +302,8 @@ async fn background_task(
builder.set_max_message_size(max_request_body_size as usize);
let (mut sender, mut receiver) = builder.finish();
let (tx, mut rx) = mpsc::unbounded::<String>();
let close_notify = Arc::new(Notify::new());
let close_notify_server_stop = close_notify.clone();
let bounded_subscriptions = Arc::new(bounded_subscriptions);
let bounded_subscriptions2 = bounded_subscriptions.clone();

let stop_server2 = stop_server.clone();
let sink = MethodSink::new_with_limit(tx, max_response_body_size);
Expand All @@ -327,7 +328,7 @@ async fn background_task(
let _ = sender.close().await;

// Notify all listeners and close down associated tasks.
close_notify_server_stop.notify_waiters();
bounded_subscriptions2.close();
});

// Buffer for incoming data.
Expand Down Expand Up @@ -436,11 +437,14 @@ async fn background_task(
},
MethodKind::Subscription(callback) => match method.claim(&req.method, &resources) {
Ok(guard) => {
let cn = close_notify.clone();
let conn_state =
ConnState { conn_id, close_notify: cn, id_provider: &*id_provider };

let result = callback(id, params, &sink, conn_state);
let result = if let Some(cn) = bounded_subscriptions.get() {
let conn_state =
ConnState { conn_id, close_notify: cn, id_provider: &*id_provider };
callback(id, params, &sink, conn_state)
} else {
sink.send_error(req.id, ErrorCode::ServerIsBusy.into());
false
};
middleware.on_result(name, result, request_start);
middleware.on_response(request_start);
drop(guard);
Expand Down Expand Up @@ -470,7 +474,7 @@ async fn background_task(
let methods = &methods;
let sink = sink.clone();
let id_provider = id_provider.clone();
let close_notify2 = close_notify.clone();
let bounded_subscriptions2 = bounded_subscriptions.clone();

let fut = async move {
// Batch responses must be sent back as a single message so we read the results from each
Expand Down Expand Up @@ -537,11 +541,17 @@ async fn background_task(
MethodKind::Subscription(callback) => {
match method_callback.claim(&req.method, resources) {
Ok(guard) => {
let close_notify = close_notify2.clone();
let conn_state =
ConnState { conn_id, close_notify, id_provider: &*id_provider };

let result = callback(id, params, &sink_batch, conn_state);
let result = if let Some(cn) = bounded_subscriptions2.get() {
let conn_state = ConnState {
conn_id,
close_notify: cn,
id_provider: &*id_provider,
};
callback(id, params, &sink_batch, conn_state)
} else {
sink_batch.send_error(req.id, ErrorCode::ServerIsBusy.into());
false
};
middleware.on_result(&req.method, result, request_start);
drop(guard);
None
Expand Down Expand Up @@ -629,6 +639,8 @@ struct Settings {
max_response_body_size: u32,
/// Maximum number of incoming connections allowed.
max_connections: u64,
/// Maximum number of subscriptions per connection.
max_subscriptions_per_connection: u32,
/// Policy by which to accept or deny incoming requests based on the `Origin` header.
allowed_origins: AllowedValue,
/// Policy by which to accept or deny incoming requests based on the `Host` header.
Expand All @@ -642,6 +654,7 @@ impl Default for Settings {
Self {
max_request_body_size: TEN_MB_SIZE_BYTES,
max_response_body_size: TEN_MB_SIZE_BYTES,
max_subscriptions_per_connection: 1024,
max_connections: MAX_CONNECTIONS,
allowed_origins: AllowedValue::Any,
allowed_hosts: AllowedValue::Any,
Expand Down

0 comments on commit a38045c

Please sign in to comment.