Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ws server): batch waits until all methods have been executed. #542

Merged
merged 8 commits into from
Nov 1, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,14 @@ criterion_group!(
SyncBencher::http_requests,
SyncBencher::batched_http_requests,
SyncBencher::websocket_requests,
// TODO: https://github.com/paritytech/jsonrpsee/issues/528
// SyncBencher::batched_ws_requests,
SyncBencher::batched_ws_requests,
);
criterion_group!(
async_benches,
AsyncBencher::http_requests,
AsyncBencher::batched_http_requests,
AsyncBencher::websocket_requests,
// TODO: https://github.com/paritytech/jsonrpsee/issues/528
// AsyncBencher::batched_ws_requests
AsyncBencher::batched_ws_requests
);
criterion_group!(subscriptions, AsyncBencher::subscriptions);
criterion_main!(types_benches, sync_benches, async_benches, subscriptions);
Expand Down
4 changes: 3 additions & 1 deletion tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ license = "MIT"
publish = false

[dev-dependencies]
env_logger = "0.8"
tracing-subscriber = "0.3.1"
tracing = "0.1"
beef = { version = "0.5.1", features = ["impl_serde"] }
futures = { version = "0.3.14", default-features = false, features = ["std"] }
jsonrpsee = { path = "../jsonrpsee", features = ["full"] }
tokio = { version = "1", features = ["full"] }
serde_json = "1"
tracing = "0.1"
7 changes: 7 additions & 0 deletions tests/tests/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ pub async fn websocket_server() -> SocketAddr {
let mut module = RpcModule::new(());
module.register_method("say_hello", |_, _| Ok("hello")).unwrap();

module
.register_async_method("slow_hello", |_, _| async {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
Ok("hello")
})
.unwrap();

let addr = server.local_addr().unwrap();

server.start(module).unwrap();
Expand Down
23 changes: 23 additions & 0 deletions tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,26 @@ async fn ws_server_should_stop_subscription_after_client_drop() {
// assert that the server received `SubscriptionClosed` after the client was dropped.
assert!(matches!(rx.next().await.unwrap(), SubscriptionClosedError { .. }));
}

#[tokio::test]
async fn ws_batch_works() {
    // Install a verbose tracing subscriber so that server/client internals are
    // visible in the test output; everything at TRACE and above goes to stdout.
    let subscriber = tracing_subscriber::FmtSubscriber::builder()
        .with_max_level(tracing::Level::TRACE)
        .finish();

    // NOTE(review): `set_global_default` panics if another test in the same
    // process already installed a subscriber — acceptable while this is the
    // only test doing so, but worth confirming if more tests add tracing.
    tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");

    // Spin up the helper websocket server and connect a client to it.
    let server_addr = websocket_server().await;
    let server_url = format!("ws://{}", server_addr);
    let client = WsClientBuilder::default().build(&server_url).await.unwrap();

    // Mix a fast method with a deliberately slow one ("slow_hello" sleeps 1s)
    // to verify the server waits for *all* calls in the batch before replying.
    let batch = vec![("say_hello", rpc_params![]), ("slow_hello", rpc_params![])];

    assert!(client.batch_request::<String>(batch).await.is_ok())
}
64 changes: 38 additions & 26 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ use crate::types::{
TEN_MB_SIZE_BYTES,
};
use futures_channel::mpsc;
use futures_util::future::FutureExt;
use futures_util::io::{BufReader, BufWriter};
use futures_util::stream::StreamExt;
use futures_util::stream::{self, StreamExt};
use soketto::handshake::{server::Response, Server as SokettoServer};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
Expand Down Expand Up @@ -296,34 +297,45 @@ async fn background_task(
}
}
Some(b'[') => {
if let Ok(batch) = serde_json::from_slice::<Vec<Request>>(&data) {
if !batch.is_empty() {
// Batch responses must be sent back as a single message so we read the results from each
// request in the batch and read the results off of a new channel, `rx_batch`, and then send the
// complete batch response back to the client over `tx`.
let (tx_batch, mut rx_batch) = mpsc::unbounded::<String>();

for fut in batch
.into_iter()
.filter_map(|req| methods.execute_with_resources(&tx_batch, req, conn_id, &resources))
{
method_executors.add(fut);
}

// Closes the receiving half of a channel without dropping it. This prevents any further
// messages from being sent on the channel.
rx_batch.close();
let results = collect_batch_response(rx_batch).await;
if let Err(err) = tx.unbounded_send(results) {
tracing::error!("Error sending batch response to the client: {:?}", err)
// Make sure the following variables are not moved into async closure below.
let d = std::mem::take(&mut data);
let resources = &resources;
let methods = &methods;
let tx2 = tx.clone();

let fut = async move {
// Batch responses must be sent back as a single message so we read the results from each
// request in the batch and read the results off of a new channel, `rx_batch`, and then send the
// complete batch response back to the client over `tx`.
let (tx_batch, mut rx_batch) = mpsc::unbounded();
if let Ok(batch) = serde_json::from_slice::<Vec<Request>>(&d) {
if !batch.is_empty() {
let methods_stream =
stream::iter(batch.into_iter().filter_map(|req| {
methods.execute_with_resources(&tx_batch, req, conn_id, resources)
}));

let results = methods_stream
.for_each_concurrent(None, |item| item)
.then(|_| {
rx_batch.close();
collect_batch_response(rx_batch)
})
.await;

if let Err(err) = tx2.unbounded_send(results) {
tracing::error!("Error sending batch response to the client: {:?}", err)
}
} else {
send_error(Id::Null, &tx2, ErrorCode::InvalidRequest.into());
}
} else {
send_error(Id::Null, &tx, ErrorCode::InvalidRequest.into());
let (id, code) = prepare_error(&d);
send_error(id, &tx2, code.into());
}
} else {
let (id, code) = prepare_error(&data);
send_error(id, &tx, code.into());
}
};

method_executors.add(Box::pin(fut));
}
_ => send_error(Id::Null, &tx, ErrorCode::ParseError.into()),
}
Expand Down