Skip to content

Commit

Permalink
Refactor ws error sending [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
FabijanC committed Nov 4, 2024
1 parent e89dae2 commit 2764baf
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 91 deletions.
159 changes: 70 additions & 89 deletions crates/starknet-devnet-server/src/api/json_rpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use models::{
EstimateFeeInput, EventsInput, GetStorageInput, L1TransactionHashInput, SubscriptionIdInput,
TransactionHashInput, TransactionHashOutput,
};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_json::json;
use starknet_core::starknet::starknet_config::{DumpOn, StarknetConfig};
Expand Down Expand Up @@ -129,37 +130,17 @@ impl RpcHandler for JsonRpcHandler {

async fn on_call(&self, call: RpcMethodCall) -> RpcResponse {
trace!(target: "rpc", id = ?call.id , method = ?call.method, "received method call");
let RpcMethodCall { method, params, id, .. } = call.clone();

let params: serde_json::Value = params.into();
let deserializable_call = serde_json::json!({
"method": &method,
"params": params
});
if !self.allows_method(&call.method) {
return RpcError::new(ErrorCode::MethodForbidden).into();
}

match serde_json::from_value::<Self::Request>(deserializable_call) {
match to_json_rpc_request(&call) {
Ok(req) => {
if let Some(restricted_methods) = &self.server_config.restricted_methods {
if is_json_rpc_method_restricted(&method, restricted_methods) {
return RpcResponse::new(id, RpcError::new(ErrorCode::MethodForbidden));
}
}
let result = self.on_request(req, call).await;
RpcResponse::new(id, result)
}
Err(err) => {
let err = err.to_string();
// since JSON-RPC specification requires returning a Method Not Found error,
// we apply a hacky way to induce this - checking the stringified error message
let distinctive_error = format!("unknown variant `{method}`");
if err.contains(&distinctive_error) {
error!(target: "rpc", ?method, "failed to deserialize method due to unknown variant");
RpcResponse::new(id, RpcError::method_not_found())
} else {
error!(target: "rpc", ?method, ?err, "failed to deserialize method");
RpcResponse::new(id, RpcError::invalid_params(err))
}
let result = self.on_request(req, call.clone()).await;
RpcResponse::new(call.id, result)
}
Err(e) => e.into(),
}
}

Expand Down Expand Up @@ -459,80 +440,55 @@ impl JsonRpcHandler {
ws: Arc<Mutex<SplitSink<WebSocket, Message>>>,
socket_id: SocketId,
) {
match serde_json::from_slice(bytes) {
let error_serialized = match serde_json::from_slice(bytes) {
Ok(call) => {
self.on_websocket_rpc_call(call, ws, socket_id).await;
// TODO removed general RPC method support - update docs
}
Err(e) => {
if let Err(e) = ws.lock().await.send(Message::Text(e.to_string())).await {
tracing::error!("Error sending websocket message: {e}");
if let Err(e) = self.on_websocket_rpc_call(&call, socket_id).await {
let rpc_error = serde_json::json!({
"jsonrpc": "2.0",
"id": call.id,
"error": e
});
rpc_error.to_string()
} else {
return;
}
}
Err(e) => e.to_string(),
};

if let Err(e) = ws.lock().await.send(Message::Text(error_serialized)).await {
tracing::error!("Error sending websocket message: {e}");
}
}

fn allows_method(&self, method: &String) -> bool {
if let Some(restricted_methods) = &self.server_config.restricted_methods {
if is_json_rpc_method_restricted(&method, restricted_methods) {
return false;
}
}

true
}

/// TODO this method contains duplication from `on_call`
/// Since some subscriptions might need to send multiple messages, sending messages other than
/// errors is left to individual RPC method handlers and this method returns an empty successful
/// Result.
async fn on_websocket_rpc_call(
&self,
call: RpcMethodCall,
ws: Arc<Mutex<SplitSink<WebSocket, Message>>>,
call: &RpcMethodCall,
socket_id: SocketId,
) {
) -> Result<(), RpcError> {
trace!(target: "rpc", id = ?call.id , method = ?call.method, "received method call");
let RpcMethodCall { method, params, id: rpc_request_id, .. } = call.clone();

let params: serde_json::Value = params.into();
let deserializable_call = serde_json::json!({
"method": &method,
"params": params
});

match serde_json::from_value::<JsonRpcSubscriptionRequest>(deserializable_call) {
Ok(req) => {
if let Some(restricted_methods) = &self.server_config.restricted_methods {
if is_json_rpc_method_restricted(&method, restricted_methods) {
let err = RpcResponse::new(
rpc_request_id.clone(),
RpcError::new(ErrorCode::MethodForbidden),
);
let err_serialized = serde_json::to_string(&err)
.unwrap_or(format!("Unserializable: {err:?}"));
if let Err(e) = ws.lock().await.send(Message::Text(err_serialized)).await {
error!("Failed sending message: {e:?}")
}
}
}

if let Err(e) = self.execute_ws(req, rpc_request_id.clone(), socket_id).await {
let rpc_err = e.api_error_to_rpc_error();
let rpc_err_serialized = serde_json::to_string(&rpc_err)
.unwrap_or(format!("Unserializable: {rpc_err:?}"));
if let Err(e) = ws.lock().await.send(Message::Text(rpc_err_serialized)).await {
error!("Failed sending message: {e:?}");
}
}
}
Err(err) => {
let err = err.to_string();
// since JSON-RPC specification requires returning a Method Not Found error,
// we apply a hacky way to induce this - checking the stringified error message
let distinctive_error = format!("unknown variant `{method}`");
let rpc_err = if err.contains(&distinctive_error) {
error!(target: "rpc", ?method, "failed to deserialize method due to unknown variant");
RpcResponse::new(rpc_request_id, RpcError::method_not_found())
} else {
error!(target: "rpc", ?method, ?err, "failed to deserialize method");
RpcResponse::new(rpc_request_id, RpcError::invalid_params(err))
};

let rpc_err_serialized = serde_json::to_string(&rpc_err)
.unwrap_or(format!("Unserializable: {rpc_err:?}"));
if let Err(e) = ws.lock().await.send(Message::Text(rpc_err_serialized)).await {
error!("Failed sending message: {e}");
}
}
if !self.allows_method(&call.method) {
return Err(RpcError::new(ErrorCode::MethodForbidden));
}

let req = to_json_rpc_request(call)?;
self.execute_ws(req, call.id.clone(), socket_id)
.await
.map_err(|e| e.api_error_to_rpc_error())
}

const DUMPABLE_METHODS: &'static [&'static str] = &[
Expand Down Expand Up @@ -717,6 +673,31 @@ pub enum JsonRpcSubscriptionRequest {
Unsubscribe(SubscriptionIdInput),
}

fn to_json_rpc_request<D>(call: &RpcMethodCall) -> Result<D, RpcError>
where
D: DeserializeOwned,
{
let params: serde_json::Value = call.params.clone().into();
let deserializable_call = serde_json::json!({
"method": call.method,
"params": params
});

serde_json::from_value::<D>(deserializable_call).map_err(|err| {
let err = err.to_string();
// since JSON-RPC specification requires returning a Method Not Found error,
// we apply a hacky way to induce this - checking the stringified error message
let distinctive_error = format!("unknown variant `{}`", call.method);
if err.contains(&distinctive_error) {
error!(target: "rpc", method = ?call.method, "failed to deserialize method due to unknown variant");
RpcError::method_not_found()
} else {
error!(target: "rpc", method = ?call.method, ?err, "failed to deserialize method");
RpcError::invalid_params(err)
}
})
}

impl std::fmt::Display for JsonRpcRequest {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.variant_name())
Expand Down
8 changes: 6 additions & 2 deletions crates/starknet-devnet/tests/test_subscription_to_blocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,12 @@ mod websocket_subscription_support {
assert_eq!(
unsubscription_resp,
json!({
"code": 66,
"message": "Invalid subscription id"
"jsonrpc": "2.0",
"id": 0,
"error": {
"code": 66,
"message": "Invalid subscription id",
}
})
);
}
Expand Down

0 comments on commit 2764baf

Please sign in to comment.