From c5e3a61f8f5eea4267a9973022bebd5a360f7b17 Mon Sep 17 00:00:00 2001 From: FabijanC Date: Fri, 8 Nov 2024 12:44:09 +0100 Subject: [PATCH] Add WebSocket RPC methods: subscribeNewHeads, unsubscribe (#634) * Do interval block creation via request --- crates/starknet-devnet-core/src/blocks/mod.rs | 8 + crates/starknet-devnet-server/Cargo.toml | 2 +- .../src/api/json_rpc/endpoints.rs | 2 +- .../src/api/json_rpc/endpoints_ws.rs | 111 +++++ .../src/api/json_rpc/error.rs | 14 + .../src/api/json_rpc/mod.rs | 231 ++++++++--- .../src/api/json_rpc/models.rs | 6 + crates/starknet-devnet-server/src/api/mod.rs | 9 +- crates/starknet-devnet-server/src/lib.rs | 1 + .../src/rpc_core/response.rs | 10 +- .../starknet-devnet-server/src/subscribe.rs | 159 ++++++++ crates/starknet-devnet/Cargo.toml | 2 +- crates/starknet-devnet/src/main.rs | 33 +- crates/starknet-devnet/tests/common/utils.rs | 65 +++ .../tests/test_subscription_to_blocks.rs | 381 ++++++++++++++++++ .../starknet-devnet/tests/test_websocket.rs | 66 ++- website/docs/api.md | 6 +- 17 files changed, 973 insertions(+), 133 deletions(-) create mode 100644 crates/starknet-devnet-server/src/api/json_rpc/endpoints_ws.rs create mode 100644 crates/starknet-devnet-server/src/subscribe.rs create mode 100644 crates/starknet-devnet/tests/test_subscription_to_blocks.rs diff --git a/crates/starknet-devnet-core/src/blocks/mod.rs b/crates/starknet-devnet-core/src/blocks/mod.rs index 41e0b6795..3c033ad52 100644 --- a/crates/starknet-devnet-core/src/blocks/mod.rs +++ b/crates/starknet-devnet-core/src/blocks/mod.rs @@ -266,6 +266,14 @@ impl StarknetBlock { } } + pub fn create_empty_accepted() -> Self { + Self { + header: BlockHeader::default(), + transaction_hashes: vec![], + status: BlockStatus::AcceptedOnL2, + } + } + pub(crate) fn set_block_number(&mut self, block_number: u64) { self.header.block_number = BlockNumber(block_number) } diff --git a/crates/starknet-devnet-server/Cargo.toml b/crates/starknet-devnet-server/Cargo.toml index c56739704..a76dadda9 100644 --- a/crates/starknet-devnet-server/Cargo.toml +++ b/crates/starknet-devnet-server/Cargo.toml @@ -31,6 +31,7 @@ thiserror = { workspace = true } anyhow = { workspace = true } lazy_static = { workspace = true } enum-helper-macros = { workspace = true } +rand = { workspace = true } # devnet starknet-core = { workspace = true } @@ -38,7 +39,6 @@ starknet-types = { workspace = true } starknet-rs-core = { workspace = true } [dev-dependencies] -rand = { workspace = true } rand_chacha = { workspace = true } regex_generate = { workspace = true } serde_yaml = { workspace = true } diff --git a/crates/starknet-devnet-server/src/api/json_rpc/endpoints.rs b/crates/starknet-devnet-server/src/api/json_rpc/endpoints.rs index dd0d9ec9c..57e052d8b 100644 --- a/crates/starknet-devnet-server/src/api/json_rpc/endpoints.rs +++ b/crates/starknet-devnet-server/src/api/json_rpc/endpoints.rs @@ -23,7 +23,7 @@ use crate::api::http::endpoints::DevnetConfig; const DEFAULT_CONTINUATION_TOKEN: &str = "0"; -/// here are the definitions and stub implementations of all JSON-RPC read endpoints +/// The definitions of JSON-RPC read endpoints defined in starknet_api_openrpc.json impl JsonRpcHandler { /// starknet_specVersion pub fn spec_version(&self) -> StrictRpcResult { diff --git a/crates/starknet-devnet-server/src/api/json_rpc/endpoints_ws.rs b/crates/starknet-devnet-server/src/api/json_rpc/endpoints_ws.rs new file mode 100644 index 000000000..d8ed1b635 --- /dev/null +++ b/crates/starknet-devnet-server/src/api/json_rpc/endpoints_ws.rs @@ -0,0 +1,111 @@ +use starknet_core::error::Error; +use starknet_rs_core::types::{BlockId, BlockTag}; +use starknet_types::starknet_api::block::BlockStatus; + +use super::error::ApiError; +use super::models::{BlockIdInput, SubscriptionIdInput}; +use super::{JsonRpcHandler, JsonRpcSubscriptionRequest}; +use crate::rpc_core::request::Id; +use crate::subscribe::{SocketId, SubscriptionNotification}; + +/// The definitions of JSON-RPC read endpoints defined in starknet_ws_api.json +impl JsonRpcHandler { + pub async fn execute_ws( + &self, + request: JsonRpcSubscriptionRequest, + rpc_request_id: Id, + socket_id: SocketId, + ) -> Result<(), ApiError> { + match request { + JsonRpcSubscriptionRequest::NewHeads(data) => { + self.subscribe_new_heads(data, rpc_request_id, socket_id).await + } + JsonRpcSubscriptionRequest::TransactionStatus => todo!(), + JsonRpcSubscriptionRequest::PendingTransactions => todo!(), + JsonRpcSubscriptionRequest::Events => todo!(), + JsonRpcSubscriptionRequest::Unsubscribe(SubscriptionIdInput { subscription_id }) => { + let mut sockets = self.api.sockets.lock().await; + let socket_context = sockets.get_mut(&socket_id).ok_or( + ApiError::StarknetDevnetError(Error::UnexpectedInternalError { + msg: format!("Unregistered socket ID: {socket_id}"), + }), + )?; + + socket_context.unsubscribe(rpc_request_id, subscription_id).await?; + Ok(()) + } + } + } + + /// starknet_subscribeNewHeads + /// Checks if an optional block ID is provided. Validates that the block exists and is not too + /// many blocks in the past. If it is a valid block, the user is notified of all blocks from the + /// old up to the latest, and subscribed to new ones. If no block ID specified, the user is just + /// subscribed to new blocks. + pub async fn subscribe_new_heads( + &self, + block_id_input: Option, + rpc_request_id: Id, + socket_id: SocketId, + ) -> Result<(), ApiError> { + let latest_tag = BlockId::Tag(BlockTag::Latest); + let block_id = if let Some(BlockIdInput { block_id }) = block_id_input { + block_id.into() + } else { + // if no block ID input, this eventually just subscribes the user to new blocks + latest_tag + }; + + let starknet = self.api.starknet.lock().await; + + // checking the block's existence; aborted blocks treated as not found + let query_block = match starknet.get_block(&block_id) { + Ok(block) => match block.status() { + BlockStatus::Rejected => Err(ApiError::BlockNotFound), + _ => Ok(block), + }, + Err(Error::NoBlock) => Err(ApiError::BlockNotFound), + Err(other) => Err(ApiError::StarknetDevnetError(other)), + }?; + + let latest_block = starknet.get_block(&latest_tag)?; + + let query_block_number = query_block.block_number().0; + let latest_block_number = latest_block.block_number().0; + + let blocks_back_amount = if query_block_number > latest_block_number { + 0 + } else { + latest_block_number - query_block_number + }; + + if blocks_back_amount > 1024 { + return Err(ApiError::TooManyBlocksBack); + } + + // perform the actual subscription + let mut sockets = self.api.sockets.lock().await; + let socket_context = sockets.get_mut(&socket_id).ok_or(ApiError::StarknetDevnetError( + Error::UnexpectedInternalError { msg: format!("Unregistered socket ID: {socket_id}") }, + ))?; + let subscription_id = socket_context.subscribe(rpc_request_id).await; + + if let BlockId::Tag(_) = block_id { + // if the specified block ID is a tag (i.e. latest/pending), no old block handling + return Ok(()); + } + + // Notifying of old blocks. latest_block_number inclusive? + // Yes, only if block_id != latest/pending (handled above) + for block_n in query_block_number..=latest_block_number { + let old_block = starknet + .get_block(&BlockId::Number(block_n)) + .map_err(ApiError::StarknetDevnetError)?; + + let notification = SubscriptionNotification::NewHeadsNotification(old_block.into()); + socket_context.notify(subscription_id, notification).await; + } + + Ok(()) + } +} diff --git a/crates/starknet-devnet-server/src/api/json_rpc/error.rs b/crates/starknet-devnet-server/src/api/json_rpc/error.rs index 1b119da65..b5ecaea69 100644 --- a/crates/starknet-devnet-server/src/api/json_rpc/error.rs +++ b/crates/starknet-devnet-server/src/api/json_rpc/error.rs @@ -62,6 +62,10 @@ pub enum ApiError { HttpApiError(#[from] HttpApiError), #[error("the compiled class hash did not match the one supplied in the transaction")] CompiledClassHashMismatch, + #[error("Cannot go back more than 1024 blocks")] + TooManyBlocksBack, + #[error("Invalid subscription id")] + InvalidSubscriptionId, } impl ApiError { @@ -205,6 +209,16 @@ impl ApiError { data: None, }, ApiError::HttpApiError(http_api_error) => http_api_error.http_api_error_to_rpc_error(), + ApiError::TooManyBlocksBack => RpcError { + code: crate::rpc_core::error::ErrorCode::ServerError(68), + message: error_message.into(), + data: None, + }, + ApiError::InvalidSubscriptionId => RpcError { + code: crate::rpc_core::error::ErrorCode::ServerError(66), + message: error_message.into(), + data: None, + }, } } } diff --git a/crates/starknet-devnet-server/src/api/json_rpc/mod.rs b/crates/starknet-devnet-server/src/api/json_rpc/mod.rs index cf3fc4d87..52bce8b61 100644 --- a/crates/starknet-devnet-server/src/api/json_rpc/mod.rs +++ b/crates/starknet-devnet-server/src/api/json_rpc/mod.rs @@ -1,4 +1,5 @@ mod endpoints; +mod endpoints_ws; pub mod error; pub mod models; pub(crate) mod origin_forwarder; @@ -8,19 +9,23 @@ mod write_endpoints; pub const RPC_SPEC_VERSION: &str = "0.7.1"; +use std::sync::Arc; + use axum::extract::ws::{Message, WebSocket}; use enum_helper_macros::{AllVariantsSerdeRenames, VariantName}; -use futures::StreamExt; +use futures::stream::SplitSink; +use futures::{SinkExt, StreamExt}; use models::{ BlockAndClassHashInput, BlockAndContractAddressInput, BlockAndIndexInput, CallInput, - EstimateFeeInput, EventsInput, GetStorageInput, L1TransactionHashInput, TransactionHashInput, - TransactionHashOutput, + EstimateFeeInput, EventsInput, GetStorageInput, L1TransactionHashInput, SubscriptionIdInput, + TransactionHashInput, TransactionHashOutput, }; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use serde_json::json; use starknet_core::starknet::starknet_config::{DumpOn, StarknetConfig}; -use starknet_core::CasmContractClass; -use starknet_rs_core::types::{ContractClass as CodegenContractClass, Felt}; +use starknet_core::{CasmContractClass, StarknetBlock}; +use starknet_rs_core::types::{BlockId, BlockTag, ContractClass as CodegenContractClass, Felt}; use starknet_types::messaging::{MessageToL1, MessageToL2}; use starknet_types::rpc::block::{Block, PendingBlock}; use starknet_types::rpc::estimate_message_fee::{ @@ -34,6 +39,7 @@ use starknet_types::rpc::transactions::{ TransactionStatus, TransactionTrace, TransactionWithHash, }; use starknet_types::starknet_api::block::BlockNumber; +use tokio::sync::Mutex; use tracing::{error, info, trace}; use self::error::StrictRpcResult; @@ -61,10 +67,11 @@ use crate::api::json_rpc::models::{ use crate::api::serde_helpers::{empty_params, optional_params}; use crate::dump_util::dump_event; use crate::restrictive_mode::is_json_rpc_method_restricted; -use crate::rpc_core::error::RpcError; +use crate::rpc_core::error::{ErrorCode, RpcError}; use crate::rpc_core::request::RpcMethodCall; use crate::rpc_core::response::{ResponseResult, RpcResponse}; use crate::rpc_handler::RpcHandler; +use crate::subscribe::{SocketContext, SocketId, SubscriptionNotification}; use crate::ServerConfig; /// Helper trait to easily convert results to rpc results @@ -123,55 +130,44 @@ impl RpcHandler for JsonRpcHandler { async fn on_call(&self, call: RpcMethodCall) -> RpcResponse { trace!(target: "rpc", id = ?call.id , method = ?call.method, "received method call"); - let RpcMethodCall { method, params, id, .. } = call.clone(); - let params: serde_json::Value = params.into(); - let deserializable_call = serde_json::json!({ - "method": &method, - "params": params - }); + if !self.allows_method(&call.method) { + return RpcResponse::from_rpc_error(RpcError::new(ErrorCode::MethodForbidden), call.id); + } - match serde_json::from_value::(deserializable_call) { + match to_json_rpc_request(&call) { Ok(req) => { - if let Some(restricted_methods) = &self.server_config.restricted_methods { - if is_json_rpc_method_restricted(&method, restricted_methods) { - return RpcResponse::new( - id, - RpcError::new(crate::rpc_core::error::ErrorCode::MethodForbidden), - ); - } - } - let result = self.on_request(req, call).await; - RpcResponse::new(id, result) - } - Err(err) => { - let err = err.to_string(); - // since JSON-RPC specification requires returning a Method Not Found error, - // we apply a hacky way to induce this - checking the stringified error message - let distinctive_error = format!("unknown variant `{method}`"); - if err.contains(&distinctive_error) { - error!(target: "rpc", ?method, "failed to deserialize method due to unknown variant"); - RpcResponse::new(id, RpcError::method_not_found()) - } else { - error!(target: "rpc", ?method, ?err, "failed to deserialize method"); - RpcResponse::new(id, RpcError::invalid_params(err)) - } + let result = self.on_request(req, call.clone()).await; + RpcResponse::new(call.id, result) } + Err(e) => RpcResponse::from_rpc_error(e, call.id), } } - async fn on_websocket(&self, mut socket: WebSocket) { - while let Some(msg) = socket.next().await { + async fn on_websocket(&self, socket: WebSocket) { + let (socket_writer, mut socket_reader) = socket.split(); + let socket_writer = Arc::new(Mutex::new(socket_writer)); + + let socket_id = rand::random(); + self.api + .sockets + .lock() + .await + .insert(socket_id, SocketContext::from_sender(socket_writer.clone())); + + // listen to new messages coming through the socket + let mut socket_safely_closed = false; + while let Some(msg) = socket_reader.next().await { match msg { Ok(Message::Text(text)) => { - self.on_websocket_rpc_call(text.as_bytes(), &mut socket).await; + self.on_websocket_call(text.as_bytes(), socket_writer.clone(), socket_id).await; } Ok(Message::Binary(bytes)) => { - self.on_websocket_rpc_call(&bytes, &mut socket).await; + self.on_websocket_call(&bytes, socket_writer.clone(), socket_id).await; } Ok(Message::Close(_)) => { - tracing::info!("Websocket disconnected"); - return; + socket_safely_closed = true; + break; } other => { tracing::error!("Socket handler got an unexpected message: {other:?}") @@ -179,7 +175,12 @@ impl RpcHandler for JsonRpcHandler { } } - tracing::error!("Failed socket read"); + if socket_safely_closed { + self.api.sockets.lock().await.remove(&socket_id); + tracing::info!("Websocket disconnected"); + } else { + tracing::error!("Failed socket read"); + } } } @@ -205,6 +206,40 @@ impl JsonRpcHandler { } } + /// The latest block is always defined, so to avoid having to deal with Err/None in places where + /// this method is called, it is defined to return an empty accepted block, even though that + /// case should never happen. + async fn get_latest_block(&self) -> StarknetBlock { + let starknet = self.api.starknet.lock().await; + match starknet.get_block(&BlockId::Tag(BlockTag::Latest)) { + Ok(block) => block.clone(), + Err(_) => StarknetBlock::create_empty_accepted(), + } + } + + async fn broadcast_changes(&self, old_latest_block: StarknetBlock) { + let new_latest_block = self.get_latest_block().await; + + let old_block_number = old_latest_block.block_number().0; + let new_block_number = new_latest_block.block_number().0; + + if new_block_number > old_block_number { + let mut sockets = self.api.sockets.lock().await; + for (_, socket_context) in sockets.iter_mut() { + socket_context + .notify_subscribers(SubscriptionNotification::NewHeadsNotification( + (&new_latest_block).into(), + )) + .await; + } + } else { + // TODO - possible only if an immutable request came or one of the following happened: + // blocks aborted, devnet restarted, devnet loaded. Or should loading cause websockets + // to be restarted too, thus not requiring notification? + tracing::debug!("Nothing happened worthy of a new block notification") + } + } + /// The method matches the request to the corresponding enum variant and executes the request async fn execute( &self, @@ -213,6 +248,9 @@ impl JsonRpcHandler { ) -> ResponseResult { trace!(target: "JsonRpcHandler::execute", "executing starknet request"); + // for later comparison and subscription notifications + let old_latest_block = self.get_latest_block().await; + // true if origin should be tried after request fails; relevant in forking mode let mut forwardable = true; @@ -370,6 +408,10 @@ impl JsonRpcHandler { } } + // TODO if request.modifies_state() { ... } - also in the beginning of this method to avoid + // unnecessary lock acquiring + self.broadcast_changes(old_latest_block).await; + if starknet_resp.is_ok() { if let Err(e) = self.update_dump(&original_call).await { return ResponseResult::Error(e); @@ -380,26 +422,49 @@ impl JsonRpcHandler { } /// Takes `bytes` to be an encoded RPC call, executes it, and sends the response back via `ws`. - async fn on_websocket_rpc_call(&self, bytes: &[u8], ws: &mut WebSocket) { - match serde_json::from_slice(bytes) { - Ok(call) => { - let resp = self.on_call(call).await; - let resp_serialized = serde_json::to_string(&resp).unwrap_or_else(|e| { - let err_msg = format!("Error converting RPC response to string: {e}"); - tracing::error!(err_msg); - err_msg - }); - - if let Err(e) = ws.send(Message::Text(resp_serialized)).await { - tracing::error!("Error sending websocket message: {e}"); - } - } - Err(e) => { - if let Err(e) = ws.send(Message::Text(e.to_string())).await { - tracing::error!("Error sending websocket message: {e}"); - } + async fn on_websocket_call( + &self, + bytes: &[u8], + ws: Arc>>, + socket_id: SocketId, + ) { + let error_serialized = match serde_json::from_slice(bytes) { + Ok(rpc_call) => match self.on_websocket_rpc_call(&rpc_call, socket_id).await { + Ok(_) => return, + Err(e) => json!(RpcResponse::from_rpc_error(e, rpc_call.id)).to_string(), + }, + Err(e) => e.to_string(), + }; + + if let Err(e) = ws.lock().await.send(Message::Text(error_serialized)).await { + tracing::error!("Error sending websocket message: {e}"); + } + } + + fn allows_method(&self, method: &String) -> bool { + if let Some(restricted_methods) = &self.server_config.restricted_methods { + if is_json_rpc_method_restricted(method, restricted_methods) { + return false; } } + + true + } + + /// Since some subscriptions might need to send multiple messages, sending messages other than + /// errors is left to individual RPC method handlers and this method returns an empty successful + /// Result. + async fn on_websocket_rpc_call( + &self, + call: &RpcMethodCall, + socket_id: SocketId, + ) -> Result<(), RpcError> { + trace!(target: "rpc", id = ?call.id , method = ?call.method, "received websocket call"); + + let req = to_json_rpc_request(call)?; + self.execute_ws(req, call.id.clone(), socket_id) + .await + .map_err(|e| e.api_error_to_rpc_error()) } const DUMPABLE_METHODS: &'static [&'static str] = &[ @@ -567,6 +632,48 @@ pub enum JsonRpcRequest { #[serde(rename = "devnet_getConfig", with = "empty_params")] DevnetConfig, } + +#[derive(Deserialize, AllVariantsSerdeRenames, VariantName)] +#[cfg_attr(test, derive(Debug))] +#[serde(tag = "method", content = "params")] +pub enum JsonRpcSubscriptionRequest { + #[serde(rename = "starknet_subscribeNewHeads", with = "optional_params")] + NewHeads(Option), + #[serde(rename = "starknet_subscribeTransactionStatus")] + TransactionStatus, + #[serde(rename = "starknet_subscribePendingTransactions")] + PendingTransactions, + #[serde(rename = "starknet_subscribeEvents")] + Events, + #[serde(rename = "starknet_unsubscribe")] + Unsubscribe(SubscriptionIdInput), +} + +fn to_json_rpc_request(call: &RpcMethodCall) -> Result +where + D: DeserializeOwned, +{ + let params: serde_json::Value = call.params.clone().into(); + let deserializable_call = json!({ + "method": call.method, + "params": params + }); + + serde_json::from_value::(deserializable_call).map_err(|err| { + let err = err.to_string(); + // since JSON-RPC specification requires returning a Method Not Found error, + // we apply a hacky way to induce this - checking the stringified error message + let distinctive_error = format!("unknown variant `{}`", call.method); + if err.contains(&distinctive_error) { + error!(target: "rpc", method = ?call.method, "failed to deserialize method due to unknown variant"); + RpcError::method_not_found() + } else { + error!(target: "rpc", method = ?call.method, ?err, "failed to deserialize method"); + RpcError::invalid_params(err) + } + }) +} + impl std::fmt::Display for JsonRpcRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.variant_name()) @@ -1246,7 +1353,7 @@ mod requests_tests { let RpcMethodCall { method, params, .. } = serde_json::from_value(json_rpc_object).unwrap(); let params: serde_json::Value = params.into(); - let deserializable_call = serde_json::json!({ + let deserializable_call = json!({ "method": &method, "params": params }); diff --git a/crates/starknet-devnet-server/src/api/json_rpc/models.rs b/crates/starknet-devnet-server/src/api/json_rpc/models.rs index 485515d17..cf5734631 100644 --- a/crates/starknet-devnet-server/src/api/json_rpc/models.rs +++ b/crates/starknet-devnet-server/src/api/json_rpc/models.rs @@ -178,6 +178,12 @@ pub struct L1TransactionHashInput { pub transaction_hash: Hash256, } +#[derive(Serialize, Deserialize, Clone, Debug)] +#[serde(deny_unknown_fields)] +pub struct SubscriptionIdInput { + pub subscription_id: i64, +} + #[cfg(test)] mod tests { use starknet_rs_core::types::{BlockId as ImportedBlockId, BlockTag, Felt}; diff --git a/crates/starknet-devnet-server/src/api/mod.rs b/crates/starknet-devnet-server/src/api/mod.rs index 59d133dcc..6c7426e8e 100644 --- a/crates/starknet-devnet-server/src/api/mod.rs +++ b/crates/starknet-devnet-server/src/api/mod.rs @@ -2,12 +2,14 @@ pub mod http; pub mod json_rpc; pub mod serde_helpers; +use std::collections::HashMap; use std::sync::Arc; use starknet_core::starknet::Starknet; use tokio::sync::Mutex; use crate::dump_util::DumpEvent; +use crate::subscribe::{SocketContext, SocketId}; /// Data that can be shared between threads with read write lock access /// Whatever needs to be accessed as information outside of Starknet could be added to this struct @@ -16,10 +18,15 @@ pub struct Api { // maybe the config should be added here next to the starknet instance pub starknet: Arc>, pub dumpable_events: Arc>>, + pub sockets: Arc>>, } impl Api { pub fn new(starknet: Starknet) -> Self { - Self { starknet: Arc::new(Mutex::new(starknet)), dumpable_events: Default::default() } + Self { + starknet: Arc::new(Mutex::new(starknet)), + dumpable_events: Default::default(), + sockets: Arc::new(Mutex::new(HashMap::new())), + } } } diff --git a/crates/starknet-devnet-server/src/lib.rs b/crates/starknet-devnet-server/src/lib.rs index 1057e083d..6adce0160 100644 --- a/crates/starknet-devnet-server/src/lib.rs +++ b/crates/starknet-devnet-server/src/lib.rs @@ -7,6 +7,7 @@ pub mod rpc_core; /// handlers for axum server pub mod rpc_handler; pub mod server; +pub mod subscribe; #[cfg(any(test, feature = "test_utils"))] pub mod test_utils; diff --git a/crates/starknet-devnet-server/src/rpc_core/response.rs b/crates/starknet-devnet-server/src/rpc_core/response.rs index ce663538a..d3b2a3a4a 100644 --- a/crates/starknet-devnet-server/src/rpc_core/response.rs +++ b/crates/starknet-devnet-server/src/rpc_core/response.rs @@ -15,12 +15,6 @@ pub struct RpcResponse { pub(crate) result: ResponseResult, } -impl From for RpcResponse { - fn from(e: RpcError) -> Self { - Self { jsonrpc: Version::V2, id: None, result: ResponseResult::Error(e) } - } -} - impl RpcResponse { pub fn new(id: Id, content: impl Into) -> Self { RpcResponse { jsonrpc: Version::V2, id: Some(id), result: content.into() } @@ -29,6 +23,10 @@ impl RpcResponse { pub fn invalid_request(id: Id) -> Self { Self::new(id, RpcError::invalid_request()) } + + pub fn from_rpc_error(e: RpcError, id: Id) -> Self { + Self { jsonrpc: Version::V2, id: Some(id), result: ResponseResult::Error(e) } + } } /// Represents the result of a call either success or error diff --git a/crates/starknet-devnet-server/src/subscribe.rs b/crates/starknet-devnet-server/src/subscribe.rs new file mode 100644 index 000000000..ed72fa199 --- /dev/null +++ b/crates/starknet-devnet-server/src/subscribe.rs @@ -0,0 +1,159 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use axum::extract::ws::{Message, WebSocket}; +use futures::stream::SplitSink; +use futures::SinkExt; +use serde::{self, Serialize}; +use starknet_types::rpc::block::BlockHeader; +use tokio::sync::Mutex; + +use crate::api::json_rpc::error::ApiError; +use crate::rpc_core::request::Id; + +pub type SocketId = u64; + +type SubscriptionId = i64; + +#[derive(Debug)] +pub enum Subscription { + NewHeads, + TransactionStatus, + PendingTransactions, + Events, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(untagged)] +pub enum SubscriptionConfirmation { + NewHeadsConfirmation(SubscriptionId), + TransactionStatusConfirmation(SubscriptionId), + PendingTransactionsConfirmation(SubscriptionId), + EventsConfirmation(SubscriptionId), + UnsubscriptionConfirmation(bool), +} + +#[derive(Debug, Clone, Serialize)] +#[serde(untagged)] +pub enum SubscriptionNotification { + NewHeadsNotification(BlockHeader), + // TransactionStatusNotification, + // PendingTransactionsNotification, + // EventsNotification, +} + +impl SubscriptionNotification { + fn method_name(&self) -> &'static str { + match self { + SubscriptionNotification::NewHeadsNotification(_) => "starknet_subscriptionNewHeads", + // SubscriptionNotification::TransactionStatusNotification => { + // "starknet_subscriptionTransactionStatus" + // } + // SubscriptionNotification::PendingTransactionsNotification => { + // "starknet_subscriptionPendingTransactions" + // } + // SubscriptionNotification::EventsNotification => "starknet_subscriptionEvents", + } + } +} + +#[derive(Debug, Clone)] +pub enum SubscriptionResponse { + Confirmation { rpc_request_id: Id, result: SubscriptionConfirmation }, + Notification { subscription_id: SubscriptionId, data: Box }, +} + +impl SubscriptionResponse { + fn to_serialized_rpc_response(&self) -> serde_json::Value { + let mut resp = match self { + SubscriptionResponse::Confirmation { rpc_request_id, result } => { + serde_json::json!({ + "id": rpc_request_id, + "result": result, + }) + } + SubscriptionResponse::Notification { subscription_id, data } => { + serde_json::json!({ + "method": data.method_name(), + "params": { + "subscription_id": subscription_id, + "result": data, + } + }) + } + }; + + resp["jsonrpc"] = "2.0".into(); + resp + } +} + +pub struct SocketContext { + /// The sender part of the socket's own channel + sender: Arc>>, + subscriptions: HashMap, +} + +impl SocketContext { + pub fn from_sender(sender: Arc>>) -> Self { + Self { sender, subscriptions: HashMap::new() } + } + + async fn send(&self, subscription_response: SubscriptionResponse) { + let resp_serialized = subscription_response.to_serialized_rpc_response().to_string(); + + if let Err(e) = self.sender.lock().await.send(Message::Text(resp_serialized)).await { + tracing::error!("Failed writing to socket: {}", e.to_string()); + } + } + + pub async fn subscribe(&mut self, rpc_request_id: Id) -> SubscriptionId { + let subscription_id = rand::random(); + self.subscriptions.insert(subscription_id, Subscription::NewHeads); + + self.send(SubscriptionResponse::Confirmation { + rpc_request_id, + result: SubscriptionConfirmation::NewHeadsConfirmation(subscription_id), + }) + .await; + + subscription_id + } + + pub async fn unsubscribe( + &mut self, + rpc_request_id: Id, + subscription_id: SubscriptionId, + ) -> Result<(), ApiError> { + match self.subscriptions.remove(&subscription_id) { + Some(_) => { + self.send(SubscriptionResponse::Confirmation { + rpc_request_id, + result: SubscriptionConfirmation::UnsubscriptionConfirmation(true), + }) + .await; + Ok(()) + } + None => Err(ApiError::InvalidSubscriptionId), + } + } + + pub async fn notify(&self, subscription_id: SubscriptionId, data: SubscriptionNotification) { + self.send(SubscriptionResponse::Notification { subscription_id, data: Box::new(data) }) + .await; + } + + pub async fn notify_subscribers(&self, data: SubscriptionNotification) { + for (subscription_id, subscription) in self.subscriptions.iter() { + match subscription { + Subscription::NewHeads => { + // The next line is here to cause a compilation error when new enum variants are + // added. Then, use `if let`. + let SubscriptionNotification::NewHeadsNotification(_) = data; + self.notify(*subscription_id, data.clone()).await; + } + other => todo!("Unsupported subscription: {other:?}"), + } + } + } +} diff --git a/crates/starknet-devnet/Cargo.toml b/crates/starknet-devnet/Cargo.toml index 0269fdba9..131973bb2 100644 --- a/crates/starknet-devnet/Cargo.toml +++ b/crates/starknet-devnet/Cargo.toml @@ -36,6 +36,7 @@ serde_json = { workspace = true } serde = { workspace = true } anyhow = { workspace = true } starknet-rs-providers = { workspace = true } +reqwest = { workspace = true } [dev-dependencies] async-trait = { workspace = true } @@ -50,7 +51,6 @@ starknet-rs-core = { workspace = true } starknet-rs-accounts = { workspace = true } axum = { workspace = true } usc = { workspace = true } -reqwest = { workspace = true } criterion = { workspace = true } serial_test = { workspace = true } tokio-tungstenite = { workspace = true } diff --git a/crates/starknet-devnet/src/main.rs b/crates/starknet-devnet/src/main.rs index 08a352ad0..60cf21130 100644 --- a/crates/starknet-devnet/src/main.rs +++ b/crates/starknet-devnet/src/main.rs @@ -6,11 +6,11 @@ use clap::Parser; use cli::Args; use futures::future::join_all; use serde::de::IntoDeserializer; +use serde_json::json; use server::api::http::HttpApiHandler; use server::api::json_rpc::{JsonRpcHandler, RPC_SPEC_VERSION}; use server::api::Api; -use server::dump_util::{dump_events, load_events, DumpEvent}; -use server::rpc_core::request::{Id, RequestParams, Version}; +use server::dump_util::{dump_events, load_events}; use server::server::serve_http_api_json_rpc; use starknet_core::account::Account; use starknet_core::constants::{ @@ -38,7 +38,7 @@ use tokio::signal::unix::{signal, SignalKind}; use tokio::signal::windows::ctrl_c; use tokio::task::{self}; use tokio::time::{interval, sleep}; -use tracing::{info, warn}; +use tracing::{error, info, warn}; use tracing_subscriber::EnvFilter; mod cli; @@ -310,7 +310,8 @@ async fn main() -> Result<(), anyhow::Error> { if let BlockGenerationOn::Interval(seconds) = starknet_config.block_generation_on { // use JoinHandle to run block interval creation as a task - let block_interval_handle = task::spawn(create_block_interval(api.clone(), seconds)); + let full_address = format!("http://{address}"); + let block_interval_handle = task::spawn(create_block_interval(seconds, full_address)); tasks.push(block_interval_handle); } @@ -336,8 +337,8 @@ async fn main() -> Result<(), anyhow::Error> { #[allow(clippy::expect_used)] async fn create_block_interval( - api: Api, block_interval_seconds: u64, + devnet_address: String, ) -> Result<(), std::io::Error> { #[cfg(unix)] let mut sigint = { signal(SignalKind::interrupt()).expect("Failed to setup SIGINT handler") }; @@ -348,21 +349,21 @@ async fn create_block_interval( Box::pin(ctrl_c_signal) }; + let devnet_client = reqwest::Client::new(); + let block_req_body = json!({ "jsonrpc": "2.0", "id": 0, "method": "devnet_createBlock" }); + + // avoid creating block instantly after startup + sleep(Duration::from_secs(block_interval_seconds)).await; + let mut interval = interval(Duration::from_secs(block_interval_seconds)); loop { - // TODO does this need to be inside of the loop? or outside? - // avoid creating block instantly after startup - sleep(Duration::from_secs(block_interval_seconds)).await; - tokio::select! { _ = interval.tick() => { - let mut starknet = api.starknet.lock().await; - let mut dumpable_events = api.dumpable_events.lock().await; - info!("Generating block on time interval"); - - // manually add event for dumping; alternative: create a client and send request - starknet.create_block().map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; - dumpable_events.push(DumpEvent { jsonrpc: Version::V2, method: "devnet_createBlock".into(), params: RequestParams::None, id: Id::Number(0) }); + // By sending a request, we take care of: 1) dumping 2) notifying subscribers + match devnet_client.post(&devnet_address).json(&block_req_body).send().await { + Ok(_) => info!("Generating block on time interval"), + Err(e) => error!("Failed block creation on time interval: {e:?}") + } } _ = sigint.recv() => { return Ok(()) diff --git a/crates/starknet-devnet/tests/common/utils.rs b/crates/starknet-devnet/tests/common/utils.rs index c58f02d58..6841bac4b 100644 --- a/crates/starknet-devnet/tests/common/utils.rs +++ b/crates/starknet-devnet/tests/common/utils.rs @@ -3,8 +3,11 @@ use std::fs; use std::path::Path; use std::process::{Child, Command}; use std::sync::Arc; +use std::time::Duration; use ethers::types::U256; +use futures::{SinkExt, StreamExt, TryStreamExt}; +use serde_json::json; use server::test_utils::assert_contains; use starknet_core::constants::CAIRO_1_ACCOUNT_CONTRACT_SIERRA_HASH; use starknet_core::random_number_generator::generate_u32_random_number; @@ -25,6 +28,8 @@ use starknet_rs_providers::{JsonRpcClient, Provider, ProviderError}; use starknet_rs_signers::LocalWallet; use starknet_types::compile_sierra_contract_json; use starknet_types::felt::felt_from_prefixed_hex; +use tokio::net::TcpStream; +use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; use super::background_devnet::BackgroundDevnet; use super::constants::{ARGENT_ACCOUNT_CLASS_HASH, CAIRO_1_CONTRACT_PATH}; @@ -387,6 +392,66 @@ pub fn assert_json_rpc_errors_equal(e1: JsonRpcError, e2: JsonRpcError) { assert_eq!((e1.code, e1.message, e1.data), (e2.code, e2.message, e2.data)); } +pub async fn send_text_rpc_via_ws( + ws: &mut WebSocketStream>, + method: &str, + params: serde_json::Value, +) -> Result { + let text_body = json!({ + "jsonrpc": "2.0", + "id": 0, + "method": method, + "params": params, + }) + .to_string(); + ws.send(tokio_tungstenite::tungstenite::Message::Text(text_body)).await?; + + let resp_raw = + ws.next().await.ok_or(anyhow::Error::msg("No response in websocket stream"))??; + let resp_body: serde_json::Value = serde_json::from_slice(&resp_raw.into_data())?; + + Ok(resp_body) +} + +pub async fn send_binary_rpc_via_ws( + ws: &mut WebSocketStream>, + method: &str, + params: serde_json::Value, +) -> Result { + let body = json!({ + "jsonrpc": "2.0", + "id": 0, + "method": method, + "params": params, + }); + let binary_body = serde_json::to_vec(&body)?; + ws.send(tokio_tungstenite::tungstenite::Message::Binary(binary_body)).await?; + + let resp_raw = + ws.next().await.ok_or(anyhow::Error::msg("No response in websocket stream"))??; + let resp_body: serde_json::Value = serde_json::from_slice(&resp_raw.into_data())?; + + Ok(resp_body) +} + +/// Tries to read from the provided ws stream. To prevent deadlock, waits for a second at most. +pub async fn receive_rpc_via_ws( + ws: &mut WebSocketStream>, +) -> Result { + let msg = tokio::time::timeout(Duration::from_secs(1), ws.try_next()) + .await?? + .ok_or(anyhow::Error::msg("Nothing to read"))?; + Ok(serde_json::from_str(&msg.into_text()?)?) +} + +pub async fn assert_no_notifications(ws: &mut WebSocketStream>) { + match receive_rpc_via_ws(ws).await { + Ok(resp) => panic!("Expected no notifications; found: {resp}"), + Err(e) if e.to_string().contains("deadline has elapsed") => { /* expected */ } + Err(e) => panic!("Expected to error out due to empty channel; found: {e}"), + } +} + #[cfg(test)] mod test_unique_auto_deletable_file { use std::path::Path; diff --git a/crates/starknet-devnet/tests/test_subscription_to_blocks.rs b/crates/starknet-devnet/tests/test_subscription_to_blocks.rs new file mode 100644 index 000000000..8ecdf5c91 --- /dev/null +++ b/crates/starknet-devnet/tests/test_subscription_to_blocks.rs @@ -0,0 +1,381 @@ +#![cfg(test)] +pub mod common; + +mod websocket_subscription_support { + use std::collections::HashMap; + use std::time::Duration; + + use serde_json::json; + use starknet_core::constants::ETH_ERC20_CONTRACT_ADDRESS; + use starknet_rs_core::types::{BlockId, BlockTag, Felt}; + use starknet_rs_providers::Provider; + use tokio::net::TcpStream; + use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; + + use crate::common::background_devnet::BackgroundDevnet; + use crate::common::utils::{assert_no_notifications, receive_rpc_via_ws, send_text_rpc_via_ws}; + + async fn subscribe_new_heads( + ws: &mut WebSocketStream>, + block_specifier: serde_json::Value, + ) -> Result { + let subscription_confirmation = + send_text_rpc_via_ws(ws, "starknet_subscribeNewHeads", block_specifier).await?; + subscription_confirmation["result"] + .as_i64() + .ok_or(anyhow::Error::msg("Subscription did not return a numeric ID")) + } + + async fn unsubscribe( + ws: &mut WebSocketStream>, + subscription_id: i64, + ) -> Result { + send_text_rpc_via_ws( + ws, + "starknet_unsubscribe", + json!({ "subscription_id": subscription_id }), + ) + .await + } + + #[tokio::test] + async fn subscribe_to_new_block_heads_happy_path() { + let devnet = BackgroundDevnet::spawn().await.unwrap(); + let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); + + let subscription_id = subscribe_new_heads(&mut ws, json!({})).await.unwrap(); + + // test with multiple blocks created, number 0 was origin, so we start at 1 + for block_i in 1..=2 { + let created_block_hash = devnet.create_block().await.unwrap(); + + let notification = receive_rpc_via_ws(&mut ws).await.unwrap(); + assert_eq!(notification["method"], "starknet_subscriptionNewHeads"); + assert_eq!( + notification["params"]["result"]["block_hash"].as_str().unwrap(), + created_block_hash.to_hex_string().as_str() + ); + + assert_eq!(notification["params"]["result"]["block_number"].as_i64().unwrap(), block_i); + assert_eq!( + notification["params"]["subscription_id"].as_i64().unwrap(), + subscription_id + ); + } + } + + #[tokio::test] + async fn should_not_receive_block_notification_if_not_subscribed() { + let devnet = BackgroundDevnet::spawn().await.unwrap(); + let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); + + devnet.create_block().await.unwrap(); + assert_no_notifications(&mut ws).await; + } + + #[tokio::test] + async fn multiple_block_subscribers_happy_path() { + let devnet = BackgroundDevnet::spawn().await.unwrap(); + + let n_subscribers = 5; + + let mut subscribers = HashMap::new(); + for _ in 0..n_subscribers { + let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); + let subscription_id = subscribe_new_heads(&mut ws, json!({})).await.unwrap(); + subscribers.insert(subscription_id, ws); + } + + assert_eq!(subscribers.len(), n_subscribers); // assert all IDs are different + + let created_block_hash = devnet.create_block().await.unwrap(); + + for (subscription_id, mut ws) in subscribers { + let notification = receive_rpc_via_ws(&mut ws).await.unwrap(); + assert_eq!(notification["method"], "starknet_subscriptionNewHeads"); + assert_eq!( + notification["params"]["result"]["block_hash"].as_str().unwrap(), + created_block_hash.to_hex_string().as_str() + ); + + assert_eq!(notification["params"]["result"]["block_number"].as_i64().unwrap(), 1); + assert_eq!( + notification["params"]["subscription_id"].as_i64().unwrap(), + subscription_id + ); + } + } + + #[tokio::test] + async fn subscription_to_an_old_block_by_number_should_notify_of_all_blocks_up_to_latest() { + let devnet = BackgroundDevnet::spawn().await.unwrap(); + let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); + + let n_blocks = 5; + for _ in 0..n_blocks { + devnet.create_block().await.unwrap(); + } + + // request notifications for all blocks starting with genesis + let subscription_id = + subscribe_new_heads(&mut ws, json!({ "block_id": BlockId::Number(0) })).await.unwrap(); + + for block_i in 0..=n_blocks { + let notification = receive_rpc_via_ws(&mut ws).await.unwrap(); + assert_eq!(notification["method"], "starknet_subscriptionNewHeads"); + + assert_eq!(notification["params"]["result"]["block_number"].as_i64().unwrap(), block_i); + assert_eq!( + notification["params"]["subscription_id"].as_i64().unwrap(), + subscription_id + ); + } + + assert_no_notifications(&mut ws).await; + } + + #[tokio::test] + async fn subscription_to_an_old_block_by_hash_should_notify_of_all_blocks_up_to_latest() { + let devnet = BackgroundDevnet::spawn().await.unwrap(); + let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); + + let genesis_block = devnet.get_latest_block_with_tx_hashes().await.unwrap(); + + let n_blocks = 5; + for _ in 0..n_blocks { + devnet.create_block().await.unwrap(); + } + + // request notifications for all blocks starting with genesis + let subscription_id = subscribe_new_heads( + &mut ws, + json!({ "block_id": BlockId::Hash(genesis_block.block_hash)}), + ) + .await + .unwrap(); + + let starting_block = 0; + for block_i in starting_block..=n_blocks { + let notification = receive_rpc_via_ws(&mut ws).await.unwrap(); + assert_eq!(notification["method"], "starknet_subscriptionNewHeads"); + + assert_eq!(notification["params"]["result"]["block_number"].as_i64().unwrap(), block_i); + assert_eq!( + notification["params"]["subscription_id"].as_i64().unwrap(), + subscription_id + ); + } + + assert_no_notifications(&mut ws).await; + } + + #[tokio::test] + async fn subscription_to_pending_block_is_same_as_latest() { + let devnet = BackgroundDevnet::spawn().await.unwrap(); + let (mut ws_latest, _) = connect_async(devnet.ws_url()).await.unwrap(); + let (mut ws_pending, _) = connect_async(devnet.ws_url()).await.unwrap(); + + // create two subscriptions: one to latest, one to pending + let subscription_id_latest = + subscribe_new_heads(&mut ws_latest, json!({ "block_id": "latest" })).await.unwrap(); + + let subscription_id_pending = + subscribe_new_heads(&mut ws_pending, json!({ "block_id": "pending" })).await.unwrap(); + + assert_ne!(subscription_id_latest, subscription_id_pending); + + devnet.create_block().await.unwrap(); + + // assert notification equality after taking subscription IDs out + let mut notification_latest = receive_rpc_via_ws(&mut ws_latest).await.unwrap(); + assert_eq!(notification_latest["params"]["subscription_id"].take(), subscription_id_latest); + assert_no_notifications(&mut ws_latest).await; + + let mut notification_pending = receive_rpc_via_ws(&mut ws_pending).await.unwrap(); + assert_eq!( + notification_pending["params"]["subscription_id"].take(), + subscription_id_pending + ); + assert_no_notifications(&mut ws_pending).await; + + assert_eq!(notification_latest, notification_pending); + } + + #[tokio::test] + async fn test_multiple_subscribers_one_unsubscribes() { + let devnet = BackgroundDevnet::spawn().await.unwrap(); + + let n_subscribers = 3; + + let mut subscribers = HashMap::new(); + for _ in 0..n_subscribers { + let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); + let subscription_id = subscribe_new_heads(&mut ws, json!({})).await.unwrap(); + subscribers.insert(subscription_id, ws); + } + + assert_eq!(subscribers.len(), n_subscribers); // assert all IDs are different + + // randomly choose one subscriber for unsubscription + let unsubscriber_id = *subscribers.keys().next().expect("Should have at least one"); + + // unsubscribe + let mut unsubscriber_ws = subscribers.remove(&unsubscriber_id).unwrap(); + let unsubscription_resp = unsubscribe(&mut unsubscriber_ws, unsubscriber_id).await.unwrap(); + assert_eq!(unsubscription_resp, json!({ "jsonrpc": "2.0", "id": 0, "result": true })); + + // create block and assert only subscribers are notified + let created_block_hash = devnet.create_block().await.unwrap(); + + for (subscription_id, mut ws) in subscribers { + let notification = receive_rpc_via_ws(&mut ws).await.unwrap(); + assert_eq!(notification["method"], "starknet_subscriptionNewHeads"); + assert_eq!( + notification["params"]["result"]["block_hash"].as_str().unwrap(), + created_block_hash.to_hex_string().as_str() + ); + assert_eq!( + notification["params"]["subscription_id"].as_i64().unwrap(), + subscription_id + ); + } + + assert_no_notifications(&mut unsubscriber_ws).await; + } + + #[tokio::test] + async fn test_unsubscribing_invalid_id() { + let devnet = BackgroundDevnet::spawn().await.unwrap(); + let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); + + let dummy_id = 123; + let unsubscription_resp = unsubscribe(&mut ws, dummy_id).await.unwrap(); + + assert_eq!( + unsubscription_resp, + json!({ + "jsonrpc": "2.0", + "id": 0, + "error": { + "code": 66, + "message": "Invalid subscription id", + } + }) + ); + } + + #[tokio::test] + async fn read_only_methods_do_not_generate_notifications() { + let devnet = BackgroundDevnet::spawn().await.unwrap(); + let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); + + subscribe_new_heads(&mut ws, json!({})).await.unwrap(); + + devnet + .json_rpc_client + .get_class_hash_at(BlockId::Tag(BlockTag::Latest), ETH_ERC20_CONTRACT_ADDRESS) + .await + .unwrap(); + + assert_no_notifications(&mut ws).await; + } + + #[tokio::test] + async fn test_notifications_in_block_on_demand_mode() { + let devnet_args = ["--block-generation-on", "demand"]; + let devnet = BackgroundDevnet::spawn_with_additional_args(&devnet_args).await.unwrap(); + + let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); + let subscription_id = subscribe_new_heads(&mut ws, json!({})).await.unwrap(); + + let dummy_address = 0x1; + devnet.mint(dummy_address, 1).await; + + assert_no_notifications(&mut ws).await; + + let created_block_hash = devnet.create_block().await.unwrap(); + + let notification = receive_rpc_via_ws(&mut ws).await.unwrap(); + assert_eq!(notification["method"], "starknet_subscriptionNewHeads"); + assert_eq!( + notification["params"]["result"]["block_hash"].as_str().unwrap(), + created_block_hash.to_hex_string().as_str() + ); + + assert_eq!(notification["params"]["result"]["block_number"].as_i64().unwrap(), 1); + assert_eq!(notification["params"]["subscription_id"].as_i64().unwrap(), subscription_id); + } + + #[tokio::test] + async fn test_notifications_on_periodic_block_generation() { + let interval = 3; + let devnet_args = ["--block-generation-on", &interval.to_string()]; + let devnet = BackgroundDevnet::spawn_with_additional_args(&devnet_args).await.unwrap(); + + let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); + let subscription_id = subscribe_new_heads(&mut ws, json!({})).await.unwrap(); + + assert_no_notifications(&mut ws).await; + + // should be enough time for Devnet to mine a single new block + tokio::time::sleep(Duration::from_secs(interval + 1)).await; + + let notification = receive_rpc_via_ws(&mut ws).await.unwrap(); + + assert_eq!(notification["method"], "starknet_subscriptionNewHeads"); + assert_eq!(notification["params"]["result"]["block_number"].as_i64().unwrap(), 1); + assert_eq!(notification["params"]["subscription_id"].as_i64().unwrap(), subscription_id); + + assert_no_notifications(&mut ws).await; + } + + #[tokio::test] + async fn test_subscribing_to_non_existent_block() { + let devnet = BackgroundDevnet::spawn().await.unwrap(); + let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); + + for block_id in [BlockId::Number(1), BlockId::Hash(Felt::ONE)] { + let subscription_resp = send_text_rpc_via_ws( + &mut ws, + "starknet_subscribeNewHeads", + json!({ "block_id": block_id }), + ) + .await + .unwrap(); + + assert_eq!( + subscription_resp, + json!({ "jsonrpc": "2.0", "id": 0, "error": { "code": 24, "message": "Block not found" } }) + ); + } + } + + #[tokio::test] + async fn test_aborted_blocks_not_subscribable() { + let devnet_args = ["--state-archive-capacity", "full"]; + let devnet = BackgroundDevnet::spawn_with_additional_args(&devnet_args).await.unwrap(); + let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); + + let new_block_hash = devnet.create_block().await.unwrap(); + devnet + .send_custom_rpc( + "devnet_abortBlocks", + json!({ "starting_block_id": BlockId::Hash(new_block_hash) }), + ) + .await + .unwrap(); + + let subscription_resp = send_text_rpc_via_ws( + &mut ws, + "starknet_subscribeNewHeads", + json!({ "block_id": BlockId::Hash(new_block_hash) }), + ) + .await + .unwrap(); + + assert_eq!( + subscription_resp, + json!({ "jsonrpc": "2.0", "id": 0, "error": { "code": 24, "message": "Block not found" } }) + ); + } +} diff --git a/crates/starknet-devnet/tests/test_websocket.rs b/crates/starknet-devnet/tests/test_websocket.rs index f149edd79..7a5c943df 100644 --- a/crates/starknet-devnet/tests/test_websocket.rs +++ b/crates/starknet-devnet/tests/test_websocket.rs @@ -2,58 +2,37 @@ pub mod common; mod websocket_support { - use futures::{SinkExt, StreamExt}; use serde_json::json; use starknet_rs_core::types::Felt; use starknet_types::rpc::transaction_receipt::FeeUnit; - use tokio::net::TcpStream; - use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; + use tokio_tungstenite::connect_async; use crate::common::background_devnet::BackgroundDevnet; + use crate::common::utils::{send_binary_rpc_via_ws, send_text_rpc_via_ws}; - async fn send_text_rpc_via_ws( - ws: &mut WebSocketStream>, - method: &str, - params: serde_json::Value, - ) -> Result { - let text_body = json!({ - "jsonrpc": "2.0", - "id": 0, - "method": method, - "params": params - }) - .to_string(); - ws.send(tokio_tungstenite::tungstenite::Message::Text(text_body)).await?; - - let resp_raw = - ws.next().await.ok_or(anyhow::Error::msg("No response in websocket stream"))??; - let resp_body: serde_json::Value = serde_json::from_slice(&resp_raw.into_data())?; - - Ok(resp_body) - } + #[tokio::test] + /// Testing for all non-ws methods would be longsome, so we just test for one devnet_ and one + /// starknet_ method + async fn test_general_rpc_support_via_websocket_is_disabled() { + let devnet = BackgroundDevnet::spawn().await.unwrap(); + let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); + + let expected_resp = + json!({"jsonrpc":"2.0", "id":0, "error":{"code":-32601, "message":"Method not found"}}); + + assert_eq!( + send_text_rpc_via_ws(&mut ws, "devnet_mint", json!({})).await.unwrap(), + expected_resp, + ); - async fn send_binary_rpc_via_ws( - ws: &mut WebSocketStream>, - method: &str, - params: serde_json::Value, - ) -> Result { - let body = json!({ - "jsonrpc": "2.0", - "id": 0, - "method": method, - "params": params - }); - let binary_body = serde_json::to_vec(&body)?; - ws.send(tokio_tungstenite::tungstenite::Message::Binary(binary_body)).await?; - - let resp_raw = - ws.next().await.ok_or(anyhow::Error::msg("No response in websocket stream"))??; - let resp_body: serde_json::Value = serde_json::from_slice(&resp_raw.into_data())?; - - Ok(resp_body) + assert_eq!( + send_text_rpc_via_ws(&mut ws, "starknet_syncing", json!({})).await.unwrap(), + expected_resp, + ); } #[tokio::test] + #[ignore = "General RPC support via websocket is disabled"] async fn mint_and_check_tx_via_websocket() { let devnet = BackgroundDevnet::spawn().await.unwrap(); let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); @@ -84,6 +63,7 @@ mod websocket_support { } #[tokio::test] + #[ignore = "General RPC support via websocket is disabled"] async fn create_block_via_binary_ws_message() { let devnet = BackgroundDevnet::spawn().await.unwrap(); let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); @@ -107,6 +87,7 @@ mod websocket_support { } #[tokio::test] + #[ignore = "General RPC support via websocket is disabled"] async fn multiple_ws_connections() { let devnet = BackgroundDevnet::spawn().await.unwrap(); let iterations = 10; @@ -137,6 +118,7 @@ mod websocket_support { } #[tokio::test] + #[ignore = "General RPC support via websocket is disabled"] async fn invalid_request() { let devnet = BackgroundDevnet::spawn().await.unwrap(); let (mut ws, _) = connect_async(devnet.ws_url()).await.unwrap(); diff --git a/website/docs/api.md b/website/docs/api.md index ab87f40cf..f1d65138c 100644 --- a/website/docs/api.md +++ b/website/docs/api.md @@ -28,13 +28,13 @@ To check if a Devnet instance is alive, send an HTTP request `GET /is_alive`. If ### WebSocket -All JSON-RPC methods can be accessed via the WebSocket protocol. Devnet is listening for new WebSocket connections at `ws://:/ws` (notice the protocol scheme). Any request body you would send to `/rpc` you can send as a text (or binary) message via WebSocket. E.g. using [`wscat`](https://www.npmjs.com/package/wscat) on the same computer where Devnet is spawned at default host and port: +JSON-RPC websocket methods can be accessed via the WebSocket protocol. Devnet is listening for new WebSocket connections at `ws://:/ws` (notice the protocol scheme). Any request body you would send to `/rpc` you can send as a text (or binary) message via WebSocket. E.g. using [`wscat`](https://www.npmjs.com/package/wscat) on the same computer where Devnet is spawned at default host and port: ``` $ wscat -c ws://127.0.0.1:5050/ws Connected (press CTRL+C to quit) -> { "jsonrpc": "2.0", "id": 0, "method": "devnet_mint", "params": { "amount": 10, "address": "0xabc" } } -< {"jsonrpc":"2.0","id":0,"result":{"new_balance":"10","unit":"WEI","tx_hash":"0x22aef5a981d547b4e8100b83d4ef82e69dff28e5888c6b1320f38da3a379ad5"}} +> { "jsonrpc": "2.0", "id": 0, "method": "starknet_subscribeNewHeads" } +< {"id":0,"result":2935616350010920547,"jsonrpc":"2.0"} ``` ## Interacting with Devnet in JavaScript and TypeScript