Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow arbitrary strings in subscription ids #1163

Merged
merged 3 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/json-rpc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ mod error;
pub use error::RpcError;

mod notification;
pub use notification::{EthNotification, PubSubItem};
pub use notification::{EthNotification, PubSubItem, SubId};

mod packet;
pub use packet::{BorrowedResponsePacket, RequestPacket, ResponsePacket};
Expand Down
19 changes: 16 additions & 3 deletions crates/json-rpc/src/notification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,22 @@ use serde::{
Deserialize, Serialize,
};

/// A subscription ID.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
#[serde(untagged)]
pub enum SubId {
/// A number.
Number(U256),
/// A string.
String(String),
}

/// An ethereum-style notification, not to be confused with a JSON-RPC
/// notification.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EthNotification<T = Box<serde_json::value::RawValue>> {
/// The subscription ID.
pub subscription: U256,
pub subscription: SubId,
/// The notification payload.
pub result: T,
}
Expand Down Expand Up @@ -128,7 +138,7 @@ impl<'de> Deserialize<'de> for PubSubItem {
#[cfg(test)]
mod test {

use crate::{EthNotification, PubSubItem};
use crate::{EthNotification, PubSubItem, SubId};

#[test]
fn deserializer_test() {
Expand All @@ -140,7 +150,10 @@ mod test {

match deser {
PubSubItem::Notification(EthNotification { subscription, result }) => {
assert_eq!(subscription, "0xcd0c3e8af590364c09d0fa6a1210faf5".parse().unwrap());
assert_eq!(
subscription,
SubId::Number("0xcd0c3e8af590364c09d0fa6a1210faf5".parse().unwrap())
);
assert_eq!(result.get(), r#"{"difficulty": "0xd9263f42a87", "uncles": []}"#);
}
_ => panic!("unexpected deserialization result"),
Expand Down
7 changes: 3 additions & 4 deletions crates/pubsub/src/managers/in_flight.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use alloy_json_rpc::{Response, ResponsePayload, SerializedRequest};
use alloy_primitives::U256;
use alloy_json_rpc::{Response, ResponsePayload, SerializedRequest, SubId};
use alloy_transport::{TransportError, TransportResult};
use std::fmt;
use tokio::sync::oneshot;
Expand Down Expand Up @@ -55,10 +54,10 @@ impl InFlight {
/// Fulfill the request with a response. This consumes the in-flight
/// request. If the request is a subscription and the response is not an
/// error, the subscription ID and the in-flight request are returned.
pub(crate) fn fulfill(self, resp: Response) -> Option<(U256, Self)> {
pub(crate) fn fulfill(self, resp: Response) -> Option<(SubId, Self)> {
if self.is_subscription() {
if let ResponsePayload::Success(val) = resp.payload {
let sub_id: serde_json::Result<U256> = serde_json::from_str(val.get());
let sub_id: serde_json::Result<SubId> = serde_json::from_str(val.get());
return match sub_id {
Ok(alias) => Some((alias, self)),
Err(e) => {
Expand Down
5 changes: 2 additions & 3 deletions crates/pubsub/src/managers/req.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::managers::InFlight;
use alloy_json_rpc::{Id, Response};
use alloy_primitives::U256;
use alloy_json_rpc::{Id, Response, SubId};
use std::collections::BTreeMap;

/// Manages in-flight requests.
Expand Down Expand Up @@ -30,7 +29,7 @@ impl RequestManager {
/// If the request created a new subscription, this function returns the
/// subscription ID and the in-flight request for conversion to an
/// `ActiveSubscription`.
pub(crate) fn handle_response(&mut self, resp: Response) -> Option<(U256, InFlight)> {
pub(crate) fn handle_response(&mut self, resp: Response) -> Option<(SubId, InFlight)> {
if let Some(in_flight) = self.reqs.remove(&resp.id) {
return in_flight.fulfill(resp);
}
Expand Down
18 changes: 9 additions & 9 deletions crates/pubsub/src/managers/sub.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use crate::{managers::ActiveSubscription, RawSubscription};
use alloy_json_rpc::{EthNotification, SerializedRequest};
use alloy_primitives::{B256, U256};
use alloy_json_rpc::{EthNotification, SerializedRequest, SubId};
use alloy_primitives::B256;
use bimap::BiBTreeMap;

#[derive(Debug, Default)]
pub(crate) struct SubscriptionManager {
/// The subscriptions.
local_to_sub: BiBTreeMap<B256, ActiveSubscription>,
/// Tracks the CURRENT server id for a subscription.
local_to_server: BiBTreeMap<B256, U256>,
local_to_server: BiBTreeMap<B256, SubId>,
}

impl SubscriptionManager {
Expand All @@ -26,7 +26,7 @@ impl SubscriptionManager {
fn insert(
&mut self,
request: SerializedRequest,
server_id: U256,
server_id: SubId,
channel_size: usize,
) -> RawSubscription {
let active = ActiveSubscription::new(request, channel_size);
Expand All @@ -43,7 +43,7 @@ impl SubscriptionManager {
pub(crate) fn upsert(
&mut self,
request: SerializedRequest,
server_id: U256,
server_id: SubId,
channel_size: usize,
) -> RawSubscription {
let local_id = request.params_hash();
Expand All @@ -59,8 +59,8 @@ impl SubscriptionManager {
}

/// De-alias an alias, getting the original ID.
pub(crate) fn local_id_for(&self, server_id: U256) -> Option<B256> {
self.local_to_server.get_by_right(&server_id).copied()
pub(crate) fn local_id_for(&self, server_id: &SubId) -> Option<B256> {
self.local_to_server.get_by_right(server_id).copied()
}

/// Drop all server_ids.
Expand All @@ -69,7 +69,7 @@ impl SubscriptionManager {
}

/// Change the server_id of a subscription.
fn change_server_id(&mut self, local_id: B256, server_id: U256) {
fn change_server_id(&mut self, local_id: B256, server_id: SubId) {
self.local_to_server.insert(local_id, server_id);
}

Expand All @@ -83,7 +83,7 @@ impl SubscriptionManager {
/// and if any receiver exists. If the sub id is unknown, or no receiver
/// exists, the notification is dropped.
pub(crate) fn notify(&mut self, notification: EthNotification) {
if let Some(local_id) = self.local_id_for(notification.subscription) {
if let Some(local_id) = self.local_id_for(&notification.subscription) {
if let Some((_, mut sub)) = self.local_to_sub.remove_by_left(&local_id) {
sub.notify(notification.result);
self.local_to_sub.insert(local_id, sub);
Expand Down
14 changes: 8 additions & 6 deletions crates/pubsub/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
managers::{InFlight, RequestManager, SubscriptionManager},
PubSubConnect, PubSubFrontend, RawSubscription,
};
use alloy_json_rpc::{Id, PubSubItem, Request, Response, ResponsePayload};
use alloy_json_rpc::{Id, PubSubItem, Request, Response, ResponsePayload, SubId};
use alloy_primitives::U256;
use alloy_transport::{
utils::{to_json_raw_value, Spawnable},
Expand Down Expand Up @@ -167,16 +167,18 @@ impl<T: PubSubConnect> PubSubService<T> {
}

/// Rewrite the subscription id and insert into the subscriptions manager
fn handle_sub_response(&mut self, in_flight: InFlight, server_id: U256) -> TransportResult<()> {
fn handle_sub_response(
&mut self,
in_flight: InFlight,
server_id: SubId,
) -> TransportResult<()> {
let request = in_flight.request;
let id = request.id().clone();

self.subs.upsert(request, server_id, in_flight.channel_size);
let sub = self.subs.upsert(request, server_id, in_flight.channel_size);

// lie to the client about the sub id.
let local_id = self.subs.local_id_for(server_id).unwrap();
// Serialized B256 is always a valid serialized U256 too.
let ser_alias = to_json_raw_value(&local_id)?;
let ser_alias = to_json_raw_value(sub.local_id())?;

// We send back a success response with the new subscription ID.
// We don't care if the channel is dead.
Expand Down