Skip to content

Commit

Permalink
server: replace FutureDriver with tokio::spawn (#1080)
Browse files Browse the repository at this point in the history
* replace FutureDriver with mpsc and tokio::task

* tokio spawn for calls

* refactor round trip for multiple calls

* cleanup

* cleanup

* fix graceful shutdown

* minor tweaks

* add test for graceful shutdown

* add test for #585

* compile warn

* fix nit
  • Loading branch information
niklasad1 authored Apr 17, 2023
1 parent d1c68bf commit 9c58d09
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 254 deletions.
10 changes: 5 additions & 5 deletions core/src/server/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ use crate::Error;
use jsonrpsee_types::error::{ErrorCode, ErrorObject, OVERSIZED_RESPONSE_CODE, OVERSIZED_RESPONSE_MSG};
use jsonrpsee_types::{Id, InvalidRequest, Response, ResponsePayload};
use serde::Serialize;
use tokio::sync::mpsc::{self, OwnedPermit};
use serde_json::value::to_raw_value;
use tokio::sync::mpsc::{self, Permit};

use super::{DisconnectError, SendTimeoutError, SubscriptionMessage, TrySendError};

Expand Down Expand Up @@ -146,7 +146,7 @@ impl MethodSink {

/// Waits for channel capacity. Once capacity to send one message is available, it is reserved for the caller.
pub async fn reserve(&self) -> Result<MethodSinkPermit, DisconnectError> {
match self.tx.reserve().await {
match self.tx.clone().reserve_owned().await {
Ok(permit) => Ok(MethodSinkPermit { tx: permit, max_log_length: self.max_log_length }),
Err(_) => Err(DisconnectError(SubscriptionMessage::empty())),
}
Expand All @@ -155,12 +155,12 @@ impl MethodSink {

/// A method sink with reserved spot in the bounded queue.
#[derive(Debug)]
pub struct MethodSinkPermit<'a> {
tx: Permit<'a, String>,
pub struct MethodSinkPermit {
tx: OwnedPermit<String>,
max_log_length: u32,
}

impl<'a> MethodSinkPermit<'a> {
impl MethodSinkPermit {
/// Send a JSON-RPC error to the client
pub fn send_error(self, id: Id, err: ErrorObject) {
let json = serde_json::to_string(&Response::new(ResponsePayload::<()>::Error(err.into_owned()), id))
Expand Down
102 changes: 1 addition & 101 deletions server/src/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,110 +26,10 @@

//! Utilities for handling async code.

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures_util::future::FutureExt;
use jsonrpsee_core::Error;
use std::sync::Arc;
use tokio::sync::{watch, OwnedSemaphorePermit, Semaphore, TryAcquireError};

/// This is a flexible collection of futures that need to be driven to completion
/// alongside some other future, such as connection handlers that need to be
/// handled along with a listener for new connections.
///
/// In order to `.await` on these futures and drive them to completion, call
/// `select_with` providing some other future, the result of which you need.
pub(crate) struct FutureDriver<F> {
futures: Vec<F>,
}

impl<F> Default for FutureDriver<F> {
fn default() -> Self {
FutureDriver { futures: Vec::new() }
}
}

impl<F> FutureDriver<F> {
/// Add a new future to this driver
pub(crate) fn add(&mut self, future: F) {
self.futures.push(future);
}
}

impl<F> FutureDriver<F>
where
F: Future + Unpin,
{
pub(crate) async fn select_with<S: Future>(&mut self, selector: S) -> S::Output {
tokio::pin!(selector);

DriverSelect { selector, driver: self }.await
}

fn drive(&mut self, cx: &mut Context) {
let mut i = 0;

while i < self.futures.len() {
if self.futures[i].poll_unpin(cx).is_ready() {
// Using `swap_remove` since we don't care about ordering
// but we do care about removing being `O(1)`.
//
// We don't increment `i` in this branch, since we now
// have a shorter length, and potentially a new value at
// current index
self.futures.swap_remove(i);
} else {
i += 1;
}
}
}
}

impl<F> Future for FutureDriver<F>
where
F: Future + Unpin,
{
type Output = ();

fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = Pin::into_inner(self);

this.drive(cx);

if this.futures.is_empty() {
Poll::Ready(())
} else {
Poll::Pending
}
}
}

/// This is a glorified select `Future` that will attempt to drive all
/// connection futures `F` to completion on each `poll`, while also
/// handling incoming connections.
struct DriverSelect<'a, S, F> {
selector: S,
driver: &'a mut FutureDriver<F>,
}

impl<'a, R, F> Future for DriverSelect<'a, R, F>
where
R: Future + Unpin,
F: Future + Unpin,
{
type Output = R::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = Pin::into_inner(self);

this.driver.drive(cx);

this.selector.poll_unpin(cx)
}
}

/// Represent a stop handle which is a wrapper over a `multi-consumer receiver`
/// and cloning [`StopHandle`] will get a separate instance of the underlying receiver.
#[derive(Debug, Clone)]
Expand Down
36 changes: 36 additions & 0 deletions server/src/tests/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -802,3 +802,39 @@ async fn notif_is_ignored() {
// This call should not be answered and a timeout is regarded as "not answered"
assert!(client.send_request_text(r#"{"jsonrpc":"2.0","method":"bar"}"#).with_default_timeout().await.is_err());
}

#[tokio::test]
async fn drop_client_with_pending_calls_works() {
init_logger();

let (handle, addr) = {
let server = ServerBuilder::default().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap();

let mut module = RpcModule::new(());

module
.register_async_method("infinite_call", |_, _| async move {
futures_util::future::pending::<()>().await;
"ok"
})
.unwrap();
let addr = server.local_addr().unwrap();

(server.start(module).unwrap(), addr)
};

let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap();

for _ in 0..10 {
let req = r#"{"jsonrpc":"2.0","method":"infinite_call","id":1}"#;
client.send(req).with_default_timeout().await.unwrap().unwrap();
}

client.close().await.unwrap();
assert!(client.receive().await.is_err());

// Stop the server and ensure that the server doesn't wait for futures to complete
// when the connection has already been closed.
handle.stop().unwrap();
assert!(handle.stopped().with_default_timeout().await.is_ok());
}
Loading

0 comments on commit 9c58d09

Please sign in to comment.