Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace async-channel #708

Merged
merged 23 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ serde = { version = "1.0", default-features = false, features = ["derive"] }
serde_json = { version = "1", features = ["raw_value"] }
soketto = "0.7.1"
parking_lot = { version = "0.12", optional = true }
tokio = { version = "1.8", features = ["rt"], optional = true }
tokio = { version = "1.8", features = ["rt", "sync"], optional = true }

[features]
default = []
Expand Down
128 changes: 72 additions & 56 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use crate::to_json_raw_value;
use crate::traits::{IdProvider, ToRpcParams};
use futures_channel::{mpsc, oneshot};
use futures_util::future::Either;
use futures_util::pin_mut;
use futures_util::{future::BoxFuture, FutureExt, Stream, StreamExt};
use jsonrpsee_types::error::{invalid_subscription_err, ErrorCode, CALL_EXECUTION_FAILED_CODE};
use jsonrpsee_types::{
Expand All @@ -46,6 +47,7 @@ use jsonrpsee_types::{
use parking_lot::Mutex;
use rustc_hash::FxHashMap;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::Notify;

/// A `MethodCallback` is an RPC endpoint, callable with a standard JSON-RPC request,
/// implemented as a function pointer to a `Fn` function taking four arguments:
Expand All @@ -62,22 +64,27 @@ pub type SubscriptionMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink,
/// Connection ID, used for stateful protocol such as WebSockets.
/// For stateless protocols such as http it's unused, so feel free to set it some hardcoded value.
pub type ConnectionId = usize;
/// Raw RPC response.
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, async_channel::Sender<()>);

/// Data for stateful connections.
/// Raw response from an RPC
/// A 3-tuple containing:
/// - Call result as a `String`,
/// - a [`mpsc::UnboundedReceiver<String>`] to receive future subscription results
/// - a [`tokio::sync::Notify`] to allow subscribers to notify their [`SubscriptionSink`] when they disconnect.
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, Arc<Notify>);

/// Helper struct to manage subscriptions.
pub struct ConnState<'a> {
/// Connection ID
pub conn_id: ConnectionId,
/// Channel to know whether the connection is closed or not.
pub close: async_channel::Receiver<()>,
/// Get notified when the connection to subscribers is closed.
pub close_notify: Arc<Notify>,
/// ID provider.
pub id_provider: &'a dyn IdProvider,
}

impl<'a> std::fmt::Debug for ConnState<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConnState").field("conn_id", &self.conn_id).field("close", &self.close).finish()
f.debug_struct("ConnState").field("conn_id", &self.conn_id).field("close", &self.close_notify).finish()
}
}

Expand Down Expand Up @@ -367,25 +374,26 @@ impl Methods {

/// Execute a callback.
async fn inner_call(&self, req: Request<'_>) -> RawRpcResponse {
let (tx, mut rx) = mpsc::unbounded();
let sink = MethodSink::new(tx);
let (close_tx, close_rx) = async_channel::unbounded();

let (tx_sink, mut rx_sink) = mpsc::unbounded();
let sink = MethodSink::new(tx_sink);
let id = req.id.clone();
let params = Params::new(req.params.map(|params| params.get()));
let notify = Arc::new(Notify::new());
dvdplm marked this conversation as resolved.
Show resolved Hide resolved

let _result = match self.method(&req.method).map(|c| &c.callback) {
None => sink.send_error(req.id, ErrorCode::MethodNotFound.into()),
Some(MethodKind::Sync(cb)) => (cb)(id, params, &sink),
Some(MethodKind::Async(cb)) => (cb)(id.into_owned(), params.into_owned(), sink, 0, None).await,
Some(MethodKind::Subscription(cb)) => {
let conn_state = ConnState { conn_id: 0, close: close_rx, id_provider: &RandomIntegerIdProvider };
let close_notify = notify.clone();
let conn_state = ConnState { conn_id: 0, close_notify, id_provider: &RandomIntegerIdProvider };
(cb)(id, params, &sink, conn_state)
}
};

let resp = rx.next().await.expect("tx and rx still alive; qed");
(resp, rx, close_tx)
let resp = rx_sink.next().await.expect("tx and rx still alive; qed");

(resp, rx_sink, notify)
}

/// Helper to create a subscription on the `RPC module` without having to spin up a server.
Expand Down Expand Up @@ -417,10 +425,11 @@ impl Methods {
let params = params.to_rpc_params()?;
let req = Request::new(sub_method.into(), Some(&params), Id::Number(0));
tracing::trace!("[Methods::subscribe] Calling subscription method: {:?}, params: {:?}", sub_method, params);
let (response, rx, tx) = self.inner_call(req).await;
let (response, rx, close_notify) = self.inner_call(req).await;
let subscription_response = serde_json::from_str::<Response<RpcSubscriptionId>>(&response)?;
let sub_id = subscription_response.result.into_owned();
Ok(Subscription { sub_id, rx, tx })
let close_notify = Some(close_notify);
Ok(Subscription { sub_id, rx, close_notify })
}

/// Returns an `Iterator` with all the method names registered on this server.
Expand Down Expand Up @@ -627,6 +636,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let ctx = self.ctx.clone();
let subscribers = Subscribers::default();

// Subscribe
{
let subscribers = subscribers.clone();
self.methods.mut_callbacks().insert(
Expand All @@ -647,7 +657,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {

let sink = SubscriptionSink {
inner: method_sink.clone(),
close: conn.close,
close_notify: Some(conn.close_notify),
method: notif_method_name,
subscribers: subscribers.clone(),
uniq_sub: SubscriptionKey { conn_id: conn.conn_id, sub_id },
Expand All @@ -668,6 +678,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
);
}

// Unsubscribe
{
self.methods.mut_callbacks().insert(
unsubscribe_method_name,
Expand Down Expand Up @@ -729,8 +740,8 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
pub struct SubscriptionSink {
/// Sink.
inner: MethodSink,
/// Close
close: async_channel::Receiver<()>,
/// Get notified when subscribers leave so we can exit
close_notify: Option<Arc<Notify>>,
dvdplm marked this conversation as resolved.
Show resolved Hide resolved
/// MethodCallback.
method: &'static str,
/// Unique subscription.
Expand Down Expand Up @@ -777,47 +788,46 @@ impl SubscriptionSink {
S: Stream<Item = T> + Unpin,
T: Serialize,
{
let mut close_stream = self.close.clone();
let mut item = stream.next();
let mut close = close_stream.next();

loop {
match futures_util::future::select(item, close).await {
Either::Left((Some(result), c)) => {
match self.send(&result) {
Ok(_) => (),
Err(Error::SubscriptionClosed(close_reason)) => {
self.close(&close_reason);
break Ok(());
}
Err(err) => {
tracing::error!("subscription `{}` failed to send item got error: {:?}", self.method, err);
break Err(err);
}
};
close = c;
item = stream.next();
}
// No messages should be sent over this channel
// if that occurred just ignore and continue.
Either::Right((Some(_), i)) => {
item = i;
close = close_stream.next();
}
// Connection terminated.
Either::Right((None, _)) => {
self.close(&SubscriptionClosed::new(SubscriptionClosedReason::ConnectionReset));
break Ok(());
if let Some(close_notify) = self.close_notify.clone() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need the clone here @dvdplm?!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't move out the Arc<Notify> and I can't use a reference either (because send takes a mutable reference below). :/

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be possible with take here but it didn't work when I tried so I just wonder why :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

take() compiles but then the tests fail. :/

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Played a bit with this, and not sure there is a super clean solution. You either clone it here, or somehow decouple the closed_fut from self lifetime wise, which is really hard with pinned futures.

let mut stream_item = stream.next();
let closed_fut = close_notify.notified();
pin_mut!(closed_fut);
loop {
// match futures_util::future::select(item, Box::pin(close_notify.notified())).await {
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
match futures_util::future::select(stream_item, closed_fut).await {
// The app sent us a value to send back to the subscribers
Either::Left((Some(result), next_closed_fut)) => {
match self.send(&result) {
Ok(_) => (),
Err(Error::SubscriptionClosed(close_reason)) => {
self.close(&close_reason);
break Ok(());
}
Err(err) => {
break Err(err);
}
};
stream_item = stream.next();
closed_fut = next_closed_fut;
}
// Stream terminated.
Either::Left((None, _)) => break Ok(()),
// The subscriber went away without telling us.
Either::Right(((), _)) => {
self.close(&SubscriptionClosed::new(SubscriptionClosedReason::ConnectionReset));
break Ok(());
}
}
// Stream terminated.
Either::Left((None, _)) => break Ok(()),
}
} else {
// The sink is closed.
return Ok(());
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
}
}

/// Returns whether this channel is closed without needing a context.
pub fn is_closed(&self) -> bool {
self.inner.is_closed() || self.close.is_closed()
self.inner.is_closed() || self.close_notify.is_none()
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
}

fn build_message<T: Serialize>(&self, result: &T) -> Result<String, Error> {
Expand Down Expand Up @@ -857,7 +867,7 @@ impl SubscriptionSink {
self.inner_close(Some(&close_reason));
}

/// Provide close from `SubscriptionClosed`.
/// Close the subscription sink with the provided [`SubscriptionClosed`].
pub fn close(&mut self, close_reason: &SubscriptionClosed) {
self.inner_close(Some(close_reason));
}
Expand All @@ -884,17 +894,19 @@ impl Drop for SubscriptionSink {
/// Wrapper struct that maintains a subscription "mainly" for testing.
#[derive(Debug)]
pub struct Subscription {
tx: async_channel::Sender<()>,
close_notify: Option<Arc<Notify>>,
rx: mpsc::UnboundedReceiver<String>,
sub_id: RpcSubscriptionId<'static>,
}

impl Subscription {
/// Close the subscription channel.
pub fn close(&mut self) {
self.tx.close();
tracing::trace!("[Subscription::close] Notifying");
if let Some(n) = self.close_notify.take() {
n.notify_one()
}
}

/// Get the subscription ID
pub fn subscription_id(&self) -> &RpcSubscriptionId {
&self.sub_id
Expand All @@ -907,6 +919,10 @@ impl Subscription {
///
/// If the decoding the value as `T` fails.
pub async fn next<T: DeserializeOwned>(&mut self) -> Option<Result<(T, RpcSubscriptionId<'static>), Error>> {
if self.close_notify.is_none() {
tracing::debug!("[Subscription::next] Closed.");
return Some(Err(Error::SubscriptionClosed(SubscriptionClosedReason::ConnectionReset.into())));
}
let raw = self.rx.next().await?;
let res = match serde_json::from_str::<SubscriptionResponse<T>>(&raw) {
Ok(r) => Ok((r.params.result, r.params.subscription.into_owned())),
Expand Down
2 changes: 2 additions & 0 deletions tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ tracing = "0.1"
serde = "1"
serde_json = "1"
hyper = { version = "0.14", features = ["http1", "client"] }
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
tokio-stream = "0.1"
58 changes: 57 additions & 1 deletion tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@
use std::sync::Arc;
use std::time::Duration;

use futures::TryStreamExt;
use helpers::{http_server, http_server_with_access_control, websocket_server, websocket_server_with_subscription};
use jsonrpsee::core::client::{ClientT, IdKind, Subscription, SubscriptionClientT};
use jsonrpsee::core::error::SubscriptionClosedReason;
use jsonrpsee::core::error::{SubscriptionClosed, SubscriptionClosedReason};
use jsonrpsee::core::{Error, JsonValue};
use jsonrpsee::http_client::HttpClientBuilder;
use jsonrpsee::rpc_params;
use jsonrpsee::ws_client::WsClientBuilder;
use tokio::time::interval;
use tokio_stream::wrappers::IntervalStream;

mod helpers;

Expand Down Expand Up @@ -379,6 +382,11 @@ async fn ws_server_should_stop_subscription_after_client_drop() {

#[tokio::test]
async fn ws_server_cancels_stream_after_reset_conn() {
tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init()
.expect("setting default subscriber failed");

use futures::{channel::mpsc, SinkExt, StreamExt};
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule};

Expand Down Expand Up @@ -415,6 +423,54 @@ async fn ws_server_cancels_stream_after_reset_conn() {
assert_eq!(Some(()), rx.next().await, "subscription stream should be terminated after the client was dropped");
}

#[tokio::test]
async fn ws_server_subscribe_with_stream() {
use futures::StreamExt;
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule};

let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let server_url = format!("ws://{}", server.local_addr().unwrap());

let mut module = RpcModule::new(());

module
.register_subscription("subscribe_5_ints", "n", "unsubscribe_5_ints", |_, sink, _| {
tokio::spawn(async move {
let interval = interval(Duration::from_millis(50));
let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c);

sink.pipe_from_stream(stream).await.unwrap();
});
Ok(())
})
.unwrap();
server.start(module).unwrap();

let client = WsClientBuilder::default().build(&server_url).await.unwrap();
let mut sub1: Subscription<usize> = client.subscribe("subscribe_5_ints", None, "unsubscribe_5_ints").await.unwrap();
let mut sub2: Subscription<usize> = client.subscribe("subscribe_5_ints", None, "unsubscribe_5_ints").await.unwrap();

let (r1, r2) = futures::future::try_join(
sub1.by_ref().take(2).try_collect::<Vec<_>>(),
sub2.by_ref().take(3).try_collect::<Vec<_>>(),
)
.await
.unwrap();

assert_eq!(r1, vec![1, 2]);
assert_eq!(r2, vec![1, 2, 3]);

// Be rude, don't run the destructor
std::mem::forget(sub2);

// sub1 is still in business, read remaining items.
assert_eq!(sub1.by_ref().take(3).try_collect::<Vec<usize>>().await.unwrap(), vec![3, 4, 5]);

let exp = SubscriptionClosed::new(SubscriptionClosedReason::Server("No close reason provided".to_string()));
// The server closed down the subscription it will send a close reason.
assert!(matches!(sub1.next().await, Some(Err(Error::SubscriptionClosed(close_reason))) if close_reason == exp));
}

#[tokio::test]
async fn ws_batch_works() {
let server_addr = websocket_server().await;
Expand Down
Loading