Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rpc module): close subscription task when a subscription is unsubscribed via the unsubscribe call #743

Merged
merged 16 commits into from
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 48 additions & 38 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use crate::id_providers::RandomIntegerIdProvider;
use crate::server::helpers::MethodSink;
use crate::server::resource_limiting::{ResourceGuard, ResourceTable, ResourceVec, Resources};
use crate::traits::{IdProvider, ToRpcParams};
use futures_channel::mpsc;
use futures_channel::{mpsc, oneshot};
use futures_util::future::Either;
use futures_util::pin_mut;
use futures_util::{future::BoxFuture, FutureExt, Stream, StreamExt, TryStream, TryStreamExt};
Expand Down Expand Up @@ -98,7 +98,7 @@ impl<'a> std::fmt::Debug for ConnState<'a> {
}
}

type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, Arc<()>)>>>;
type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, oneshot::Sender<()>)>>>;

/// Represent a unique subscription entry based on [`RpcSubscriptionId`] and [`ConnectionId`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -794,9 +794,9 @@ impl PendingSubscription {
let InnerPendingSubscription { sink, close_notify, method, uniq_sub, subscribers, id } = inner;

if sink.send_response(id, &uniq_sub.sub_id) {
let active_sub = Arc::new(());
subscribers.lock().insert(uniq_sub.clone(), (sink.clone(), active_sub.clone()));
Some(SubscriptionSink { inner: sink, close_notify, method, uniq_sub, subscribers, active_sub })
let (tx, rx) = oneshot::channel();
subscribers.lock().insert(uniq_sub.clone(), (sink.clone(), tx));
Some(SubscriptionSink { inner: sink, close_notify, method, uniq_sub, subscribers, unsubscribe: Some(rx) })
} else {
None
}
Expand Down Expand Up @@ -826,7 +826,8 @@ pub struct SubscriptionSink {
uniq_sub: SubscriptionKey,
/// Shared Mutex of subscriptions for this method.
subscribers: Subscribers,
active_sub: Arc<()>,
/// Future that returns when the `unsubscribe method ` has been called.
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
unsubscribe: Option<oneshot::Receiver<()>>,
}

impl SubscriptionSink {
Expand All @@ -843,7 +844,7 @@ impl SubscriptionSink {
}

let msg = self.build_message(result)?;
Ok(self.inner_send(msg))
Ok(self.inner.send_raw(msg).is_ok())
}

/// Reads data from the `stream` and sends back data on the subscription
Expand Down Expand Up @@ -871,7 +872,7 @@ impl SubscriptionSink {
/// tokio::spawn(async move {
/// // jsonrpsee doesn't send an error notification unless `close` is explicitly called.
/// // If we pipe messages to the sink, we can inspect why it ended:
/// match sink.pipe_from_try_stream(stream).await {
/// sink.pipe_from_try_stream(stream, |close, sink| match close {
/// SubscriptionClosed::Success => {
/// let err_obj: ErrorObjectOwned = SubscriptionClosed::Success.into();
/// sink.close(err_obj);
Expand All @@ -881,25 +882,40 @@ impl SubscriptionSink {
/// SubscriptionClosed::Failed(e) => {
/// sink.close(e);
/// }
/// };
/// }).await;
/// });
/// });
/// ```
pub async fn pipe_from_try_stream<S, T, E>(&mut self, mut stream: S) -> SubscriptionClosed
pub async fn pipe_from_try_stream<S, T, E, F>(mut self, mut stream: S, on_close: F)
where
S: TryStream<Ok = T, Error = E> + Unpin,
T: Serialize,
E: std::fmt::Display,
F: FnOnce(SubscriptionClosed, SubscriptionSink),
{
let close_notify = match self.close_notify.clone() {
let conn_closed = match self.close_notify.clone() {
Some(close_notify) => close_notify,
None => return SubscriptionClosed::RemotePeerAborted,
None => {
on_close(SubscriptionClosed::RemotePeerAborted, self);
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
return;
}
};

let sub_closed = match self.unsubscribe.take() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, so this oneshot unsubscribe channel is removed here, which means we can only use pipe_from_try_stream once.

Is it possible to take a mutable ref to it instead to use below (which may need pinning)? Maybe then this method would still be reusable?

Copy link
Member Author

@niklasad1 niklasad1 Apr 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it won't work because send takes a &self as well.

I changed it to tokio::sync::watch for ease of use.

Some(unsub) => unsub,
None => {
on_close(SubscriptionClosed::RemotePeerAborted, self);
return;
}
};

let conn_closed_fut = conn_closed.notified();
pin_mut!(conn_closed_fut);

let mut stream_item = stream.try_next();
let closed_fut = close_notify.notified();
pin_mut!(closed_fut);
loop {
let mut closed_fut = futures_util::future::select(conn_closed_fut, sub_closed);

let close = loop {
match futures_util::future::select(stream_item, closed_fut).await {
// The app sent us a value to send back to the subscribers
Either::Left((Ok(Some(result)), next_closed_fut)) => {
Expand All @@ -922,11 +938,13 @@ impl SubscriptionSink {
break SubscriptionClosed::Failed(err);
}
Either::Left((Ok(None), _)) => break SubscriptionClosed::Success,
Either::Right(((), _)) => {
Either::Right((_, _)) => {
break SubscriptionClosed::RemotePeerAborted;
}
}
}
};

on_close(close, self);
}

/// Similar to [`SubscriptionSink::pipe_from_try_stream`] but it doesn't require the stream return `Result`.
Expand All @@ -945,24 +963,25 @@ impl SubscriptionSink {
/// m.register_subscription("sub", "_", "unsub", |params, pending, _| {
/// let mut sink = pending.accept().unwrap();
/// let stream = futures_util::stream::iter(vec![1_usize, 2, 3]);
/// tokio::spawn(async move { sink.pipe_from_stream(stream).await; });
/// tokio::spawn(async move { sink.pipe_from_stream(stream, |_, _| {}).await; });
/// });
/// ```
pub async fn pipe_from_stream<S, T>(&mut self, stream: S) -> SubscriptionClosed
pub async fn pipe_from_stream<S, T, F>(self, stream: S, on_close: F)
where
S: Stream<Item = T> + Unpin,
T: Serialize,
F: FnOnce(SubscriptionClosed, SubscriptionSink),
{
self.pipe_from_try_stream::<_, _, Error>(stream.map(|item| Ok(item))).await
self.pipe_from_try_stream::<_, _, Error, _>(stream.map(|item| Ok(item)), on_close).await
}

/// Returns whether this channel is closed without needing a context.
/// Returns whether the subscription is closed.
pub fn is_closed(&self) -> bool {
self.inner.is_closed() || self.close_notify.is_none()
self.inner.is_closed() || self.close_notify.is_none() || !self.is_active_subscription()
}

fn is_active_subscription(&self) -> bool {
Arc::strong_count(&self.active_sub) > 1
self.subscribers.lock().contains_key(&self.uniq_sub)
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
}

fn build_message<T: Serialize>(&self, result: &T) -> Result<String, serde_json::Error> {
Expand All @@ -981,14 +1000,6 @@ impl SubscriptionSink {
.map_err(Into::into)
}

fn inner_send(&mut self, msg: String) -> bool {
if self.is_active_subscription() {
self.inner.send_raw(msg).is_ok()
} else {
false
}
}

/// Close the subscription, sending a notification with a special `error` field containing the provided error.
///
/// This can be used to signal an actual error, or just to signal that the subscription has been closed,
Expand All @@ -1010,15 +1021,14 @@ impl SubscriptionSink {
/// ```
///
pub fn close(self, err: impl Into<ErrorObjectOwned>) -> bool {
if self.is_active_subscription() {
if let Some((sink, _)) = self.subscribers.lock().remove(&self.uniq_sub) {
tracing::debug!("Closing subscription: {:?}", self.uniq_sub.sub_id);
if let Some((sink, _)) = self.subscribers.lock().remove(&self.uniq_sub) {
tracing::debug!("Closing subscription: {:?}", self.uniq_sub.sub_id);

let msg = self.build_error_message(&err.into()).expect("valid json infallible; qed");
return sink.send_raw(msg).is_ok();
}
let msg = self.build_error_message(&err.into()).expect("valid json infallible; qed");
sink.send_raw(msg).is_ok()
} else {
false
}
false
}
}

Expand Down
7 changes: 4 additions & 3 deletions examples/ws_pubsub_broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,22 @@ async fn run_server() -> anyhow::Result<SocketAddr> {

module.register_subscription("subscribe_hello", "s_hello", "unsubscribe_hello", move |_, pending, _| {
let rx = BroadcastStream::new(tx.clone().subscribe());
let mut sink = match pending.accept() {
let sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};

tokio::spawn(async move {
match sink.pipe_from_try_stream(rx).await {
sink.pipe_from_try_stream(rx, |reason, sink| match reason {
SubscriptionClosed::Success => {
sink.close(SubscriptionClosed::Success);
}
SubscriptionClosed::RemotePeerAborted => (),
SubscriptionClosed::Failed(err) => {
sink.close(err);
}
};
})
.await;
});
})?;
let addr = server.local_addr()?;
Expand Down
14 changes: 8 additions & 6 deletions examples/ws_pubsub_with_params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async fn run_server() -> anyhow::Result<SocketAddr> {
let mut module = RpcModule::new(());
module
.register_subscription("sub_one_param", "sub_one_param", "unsub_one_param", |params, pending, _| {
let (idx, mut sink) = match (params.one(), pending.accept()) {
let (idx, sink) = match (params.one(), pending.accept()) {
(Ok(idx), Some(sink)) => (idx, sink),
_ => return,
};
Expand All @@ -77,7 +77,7 @@ async fn run_server() -> anyhow::Result<SocketAddr> {
let stream = IntervalStream::new(interval).map(move |_| item);

tokio::spawn(async move {
match sink.pipe_from_stream(stream).await {
sink.pipe_from_stream(stream, |reason, sink| match reason {
// Send close notification when subscription stream failed.
SubscriptionClosed::Failed(err) => {
sink.close(err);
Expand All @@ -86,13 +86,14 @@ async fn run_server() -> anyhow::Result<SocketAddr> {
SubscriptionClosed::Success => (),
// Don't send close because the client has already disconnected.
SubscriptionClosed::RemotePeerAborted => (),
};
})
.await;
});
})
.unwrap();
module
.register_subscription("sub_params_two", "params_two", "unsub_params_two", |params, pending, _| {
let (one, two, mut sink) = match (params.parse::<(usize, usize)>(), pending.accept()) {
let (one, two, sink) = match (params.parse::<(usize, usize)>(), pending.accept()) {
(Ok((one, two)), Some(sink)) => (one, two, sink),
_ => return,
};
Expand All @@ -103,7 +104,7 @@ async fn run_server() -> anyhow::Result<SocketAddr> {
let stream = IntervalStream::new(interval).map(move |_| item);

tokio::spawn(async move {
match sink.pipe_from_stream(stream).await {
sink.pipe_from_stream(stream, |reason, sink| match reason {
// Send close notification when subscription stream failed.
SubscriptionClosed::Failed(err) => {
sink.close(err);
Expand All @@ -112,7 +113,8 @@ async fn run_server() -> anyhow::Result<SocketAddr> {
SubscriptionClosed::Success => (),
// Don't send close because the client has already disconnected.
SubscriptionClosed::RemotePeerAborted => (),
};
})
.await
});
})
.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion proc-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ pub(crate) mod visitor;
/// let mut sink = pending.accept().unwrap();
/// tokio::spawn(async move {
/// let stream = futures_util::stream::iter(["one", "two", "three"]);
/// sink.pipe_from_stream(stream).await;
/// sink.pipe_from_stream(stream, |_, _| {}).await;
/// });
/// }
///
Expand Down
82 changes: 82 additions & 0 deletions tests/tests/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@
use std::net::SocketAddr;
use std::time::Duration;

use futures::{SinkExt, StreamExt};
use jsonrpsee::core::error::SubscriptionClosed;
use jsonrpsee::http_server::{AccessControl, HttpServerBuilder, HttpServerHandle};
use jsonrpsee::types::error::{ErrorObject, SUBSCRIPTION_CLOSED_WITH_ERROR};
use jsonrpsee::ws_server::{WsServerBuilder, WsServerHandle};
use jsonrpsee::RpcModule;
use tokio::time::interval;
use tokio_stream::wrappers::IntervalStream;

pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle) {
let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
Expand Down Expand Up @@ -108,6 +112,56 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle
})
.unwrap();

module
.register_subscription("subscribe_5_ints", "n", "unsubscribe_5_ints", |_, pending, _| {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now we can call pipe_from_stream more than once, should we also have a test to make sure that we can?

let sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};

tokio::spawn(async move {
let interval = interval(Duration::from_millis(50));
let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c);

sink.pipe_from_stream(stream, |close, sink| match close {
SubscriptionClosed::Success => {
sink.close(SubscriptionClosed::Success);
}
_ => unreachable!(),
})
.await;
});
})
.unwrap();

module
.register_subscription(
"subscribe_with_err_on_stream",
"n",
"unsubscribe_with_err_on_stream",
move |_, pending, _| {
let sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};

let err: &'static str = "error on the stream";

// create stream that produce an error which will cancel the subscription.
let stream = futures::stream::iter(vec![Ok(1_u32), Err(err), Ok(2), Ok(3)]);
tokio::spawn(async move {
sink.pipe_from_try_stream(stream, |close, sink| match close {
SubscriptionClosed::Failed(e) => {
sink.close(e);
}
_ => unreachable!(),
})
.await;
});
},
)
.unwrap();

let addr = server.local_addr().unwrap();
let server_handle = server.start(module).unwrap();

Expand All @@ -133,6 +187,34 @@ pub async fn websocket_server() -> SocketAddr {
addr
}

/// Yields at one item then sleeps for an hour.
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
pub async fn websocket_server_with_sleeping_subscription(tx: futures::channel::mpsc::Sender<()>) -> SocketAddr {
let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let addr = server.local_addr().unwrap();

let mut module = RpcModule::new(tx);

module
.register_subscription("subscribe_sleep", "n", "unsubscribe_sleep", |_, pending, mut tx| {
let sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};

tokio::spawn(async move {
let interval = interval(Duration::from_secs(60 * 60));
let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c);

sink.pipe_from_stream(stream, |_, _| {}).await;
let send_back = std::sync::Arc::make_mut(&mut tx);
send_back.send(()).await.unwrap();
});
})
.unwrap();
server.start(module).unwrap();
addr
}

pub async fn http_server() -> (SocketAddr, HttpServerHandle) {
http_server_with_access_control(AccessControl::default()).await
}
Expand Down
Loading