diff --git a/client/ws-client/src/tests.rs b/client/ws-client/src/tests.rs index b9b7a19468..c4254bd64a 100644 --- a/client/ws-client/src/tests.rs +++ b/client/ws-client/src/tests.rs @@ -191,6 +191,26 @@ async fn notification_handler_works() { } } +#[tokio::test] +async fn batched_notification_handler_works() { + let server = WebSocketTestServer::with_hardcoded_notification( + "127.0.0.1:0".parse().unwrap(), + server_batched_notification("test", "batched server originated notification works".into()), + ) + .with_default_timeout() + .await + .unwrap(); + + let uri = to_ws_uri_string(server.local_addr()); + let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap(); + { + let mut nh: Subscription = + client.subscribe_to_method("test").with_default_timeout().await.unwrap().unwrap(); + let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap(); + assert_eq!("batched server originated notification works".to_owned(), response); + } +} + #[tokio::test] async fn notification_close_on_lagging() { init_logger(); diff --git a/core/src/client/async_client/mod.rs b/core/src/client/async_client/mod.rs index 73cfc0cf34..6ebb5f8827 100644 --- a/core/src/client/async_client/mod.rs +++ b/core/src/client/async_client/mod.rs @@ -767,32 +767,36 @@ fn handle_backend_messages( let mut batch = Vec::with_capacity(raw_responses.len()); let mut range = None; + let mut got_notif = false; for r in raw_responses { - let Ok(response) = serde_json::from_str::>(r.get()) else { + if let Ok(response) = serde_json::from_str::>(r.get()) { + let id = response.id.try_parse_inner_as_number()?; + let result = ResponseSuccess::try_from(response).map(|s| s.result); + batch.push(InnerBatchResponse { id, result }); + + let r = range.get_or_insert(id..id); + + if id < r.start { + r.start = id; + } + + if id > r.end { + r.end = id; + } + } else if let Ok(notif) = serde_json::from_str::>(r.get()) { + got_notif = true; + process_notification(&mut manager.lock(), notif); + } else { return Err(unparse_error(raw)); }; - - let id = response.id.try_parse_inner_as_number()?; - let result = ResponseSuccess::try_from(response).map(|s| s.result); - batch.push(InnerBatchResponse { id, result }); - - let r = range.get_or_insert(id..id); - - if id < r.start { - r.start = id; - } - - if id > r.end { - r.end = id; - } } if let Some(mut range) = range { // the range is exclusive so need to add one. range.end += 1; process_batch_response(&mut manager.lock(), batch, range)?; - } else { + } else if !got_notif { return Err(EmptyBatchRequest.into()); } } else { diff --git a/test-utils/src/helpers.rs b/test-utils/src/helpers.rs index 2dc5854731..16d7e8184d 100644 --- a/test-utils/src/helpers.rs +++ b/test-utils/src/helpers.rs @@ -186,6 +186,11 @@ pub fn server_notification(method: &str, params: Value) -> String { format!(r#"{{"jsonrpc":"2.0","method":"{}", "params":{} }}"#, method, serde_json::to_string(¶ms).unwrap()) } +/// Batched server originated notification +pub fn server_batched_notification(method: &str, params: Value) -> String { + format!(r#"[{{"jsonrpc":"2.0","method":"{}", "params":{} }}]"#, method, serde_json::to_string(¶ms).unwrap()) +} + pub async fn http_request(body: Body, uri: Uri) -> Result { let client = hyper::Client::new(); http_post(client, body, uri).await