diff --git a/src/ws/raw/core.rs b/src/ws/raw/core.rs index 701471763a..3ae9ccf2c6 100644 --- a/src/ws/raw/core.rs +++ b/src/ws/raw/core.rs @@ -31,6 +31,7 @@ use crate::ws::transport::{TransportServerEvent, WsRequestId as RequestId, WsTra use alloc::{borrow::ToOwned as _, string::String, vec, vec::Vec}; use core::{fmt, hash::Hash, num::NonZeroUsize}; use hashbrown::{hash_map::Entry, HashMap}; +use std::convert::TryFrom; /// Wraps around a "raw server" and adds capabilities. /// @@ -399,6 +400,22 @@ impl RawServerSubscriptionId { } } +// Try to parse a subscription ID from `Params` where we try both index 0 of an array or `subscription` +// in a `Map`. +impl<'a> TryFrom> for RawServerSubscriptionId { + type Error = (); + + fn try_from(params: Params) -> Result { + if let Ok(other_id) = params.get(0) { + Self::from_wire_message(&other_id) + } else if let Ok(other_id) = params.get("subscription") { + Self::from_wire_message(&other_id) + } else { + Err(()) + } + } +} + impl<'a> ServerSubscription<'a> { /// Returns the id of the subscription. /// diff --git a/src/ws/server.rs b/src/ws/server.rs index a049c977bc..aaec8ab01b 100644 --- a/src/ws/server.rs +++ b/src/ws/server.rs @@ -32,6 +32,7 @@ use futures::{channel::mpsc, future::Either, pin_mut, prelude::*}; use parking_lot::Mutex; use std::{ collections::{HashMap, HashSet}, + convert::TryFrom, error, net::SocketAddr, sync::{atomic, Arc}, @@ -410,8 +411,8 @@ async fn background_task(mut server: RawServer, mut from_front: mpsc::UnboundedR } } Either::Right(RawServerEvent::Request(request)) => { - log::debug!("[backend]: server received request: {:?}", request); if let Some(handler) = registered_methods.get_mut(request.method()) { + log::debug!("[backend]: server received request: {:?}", request); let params: &common::Params = request.params().into(); log::debug!("server called handler"); match handler.send((request.id(), params.clone())).now_or_never() { @@ -421,6 +422,7 @@ async fn background_task(mut server: RawServer, mut from_front: mpsc::UnboundedR } } } else if let Some(sub_unique_id) = subscribe_methods.get(request.method()) { + log::debug!("[backend]: server received subscription: {:?}", request); if let Ok(sub_id) = request.into_subscription() { debug_assert!(subscribed_clients.contains_key(&sub_unique_id)); if let Some(clients) = subscribed_clients.get_mut(&sub_unique_id) { @@ -432,21 +434,24 @@ async fn background_task(mut server: RawServer, mut from_front: mpsc::UnboundedR active_subscriptions.insert(sub_id, *sub_unique_id); } } else if let Some(sub_unique_id) = unsubscribe_methods.get(request.method()) { - if let Ok(sub_id) = RawServerSubscriptionId::from_wire_message(&JsonValue::Null) { - // FIXME: from request params - debug_assert!(subscribed_clients.contains_key(&sub_unique_id)); - if let Some(clients) = subscribed_clients.get_mut(&sub_unique_id) { - // TODO: we don't actually check whether the unsubscribe comes from the right - // clients, but since this the ID is randomly-generated, it should be - // fine - if let Some(client_pos) = clients.iter().position(|c| *c == sub_id) { - clients.remove(client_pos); - } - - if let Some(s_u_id) = active_subscriptions.remove(&sub_id) { - debug_assert_eq!(s_u_id, *sub_unique_id); + log::debug!("[backend]: server received unsubscription: {:?}", request); + match RawServerSubscriptionId::try_from(request.params()) { + Ok(sub_id) => { + debug_assert!(subscribed_clients.contains_key(&sub_unique_id)); + if let Some(clients) = subscribed_clients.get_mut(&sub_unique_id) { + // TODO: we don't actually check whether the unsubscribe comes from the right + // clients, but since this the ID is randomly-generated, it should be + // fine + if let Some(client_pos) = clients.iter().position(|c| *c == sub_id) { + clients.remove(client_pos); + } + + if let Some(s_u_id) = active_subscriptions.remove(&sub_id) { + debug_assert_eq!(s_u_id, *sub_unique_id); + } } } + Err(_) => log::error!("Unsubscription of method=\"{}\" failed; The subscription ID must passed as the first argument of Array or \"subscription\" name of Object, got={:?}", request.method(), request.params()), } } else { // TODO: we assert that the request is valid because the parsing succeeded but