Skip to content

Commit

Permalink
feat(trie): deserialize trie updates with serde as hex (#11369)
Browse files Browse the repository at this point in the history
  • Loading branch information
shekhirin authored Oct 1, 2024
1 parent 7fab4c3 commit 86f12b7
Showing 1 changed file with 137 additions and 43 deletions.
180 changes: 137 additions & 43 deletions crates/trie/trie/src/updates.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use crate::{walker::TrieWalker, BranchNodeCompact, HashBuilder, Nibbles};
use alloy_primitives::B256;
#[cfg(feature = "serde")]
use serde::{ser::SerializeMap, Serialize, Serializer};
use std::collections::{HashMap, HashSet};

/// The aggregation of trie updates.
#[derive(PartialEq, Eq, Clone, Default, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TrieUpdates {
#[cfg_attr(feature = "serde", serde(serialize_with = "serialize_nibbles_map"))]
#[cfg_attr(feature = "serde", serde(with = "serde_nibbles_map"))]
pub(crate) account_nodes: HashMap<Nibbles, BranchNodeCompact>,
#[cfg_attr(feature = "serde", serde(serialize_with = "serialize_nibbles_set"))]
#[cfg_attr(feature = "serde", serde(with = "serde_nibbles_set"))]
pub(crate) removed_nodes: HashSet<Nibbles>,
pub(crate) storage_tries: HashMap<B256, StorageTrieUpdates>,
}
Expand Down Expand Up @@ -117,10 +115,10 @@ pub struct StorageTrieUpdates {
/// Flag indicating whether the trie was deleted.
pub(crate) is_deleted: bool,
/// Collection of updated storage trie nodes.
#[cfg_attr(feature = "serde", serde(serialize_with = "serialize_nibbles_map"))]
#[cfg_attr(feature = "serde", serde(with = "serde_nibbles_map"))]
pub(crate) storage_nodes: HashMap<Nibbles, BranchNodeCompact>,
/// Collection of removed storage trie nodes.
#[cfg_attr(feature = "serde", serde(serialize_with = "serialize_nibbles_set"))]
#[cfg_attr(feature = "serde", serde(with = "serde_nibbles_set"))]
pub(crate) removed_nodes: HashSet<Nibbles>,
}

Expand Down Expand Up @@ -222,40 +220,118 @@ impl StorageTrieUpdates {
}
}

/// Serializes any [`HashSet`] that includes [`Nibbles`] elements, by using the hex-encoded packed
/// representation.
/// Serializes and deserializes any [`HashSet`] that includes [`Nibbles`] elements, by using the
/// hex-encoded packed representation.
///
/// This also sorts the set before serializing.
#[cfg(feature = "serde")]
fn serialize_nibbles_set<S>(map: &HashSet<Nibbles>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut storage_nodes =
Vec::from_iter(map.iter().map(|elem| alloy_primitives::hex::encode(elem.pack())));
storage_nodes.sort_unstable();
storage_nodes.serialize(serializer)
mod serde_nibbles_set {
use std::collections::HashSet;

use reth_trie_common::Nibbles;
use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};

pub(super) fn serialize<S>(map: &HashSet<Nibbles>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut storage_nodes =
Vec::from_iter(map.iter().map(|elem| alloy_primitives::hex::encode(elem.pack())));
storage_nodes.sort_unstable();
storage_nodes.serialize(serializer)
}

pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<HashSet<Nibbles>, D::Error>
where
D: Deserializer<'de>,
{
Vec::<String>::deserialize(deserializer)?
.into_iter()
.map(|node| {
Ok(Nibbles::unpack(
alloy_primitives::hex::decode(node)
.map_err(|err| D::Error::custom(err.to_string()))?,
))
})
.collect::<Result<HashSet<_>, _>>()
}
}

/// Serializes any [`HashMap`] that uses [`Nibbles`] as keys, by using the hex-encoded packed
/// representation.
/// Serializes and deserializes any [`HashMap`] that uses [`Nibbles`] as keys, by using the
/// hex-encoded packed representation.
///
/// This also sorts the map's keys before encoding and serializing.
#[cfg(feature = "serde")]
fn serialize_nibbles_map<S, T>(map: &HashMap<Nibbles, T>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
T: Serialize,
{
let mut map_serializer = serializer.serialize_map(Some(map.len()))?;
let mut storage_nodes = Vec::from_iter(map);
storage_nodes.sort_unstable_by(|a, b| a.0.cmp(b.0));
for (k, v) in storage_nodes {
// pack, then hex encode the Nibbles
let packed = alloy_primitives::hex::encode(k.pack());
map_serializer.serialize_entry(&packed, &v)?;
}
map_serializer.end()
mod serde_nibbles_map {
use std::{collections::HashMap, marker::PhantomData};

use alloy_primitives::hex;
use reth_trie_common::Nibbles;
use serde::{
de::{Error, MapAccess, Visitor},
ser::SerializeMap,
Deserialize, Deserializer, Serialize, Serializer,
};

pub(super) fn serialize<S, T>(
map: &HashMap<Nibbles, T>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
T: Serialize,
{
let mut map_serializer = serializer.serialize_map(Some(map.len()))?;
let mut storage_nodes = Vec::from_iter(map);
storage_nodes.sort_unstable_by_key(|node| node.0);
for (k, v) in storage_nodes {
// pack, then hex encode the Nibbles
let packed = alloy_primitives::hex::encode(k.pack());
map_serializer.serialize_entry(&packed, &v)?;
}
map_serializer.end()
}

pub(super) fn deserialize<'de, D, T>(deserializer: D) -> Result<HashMap<Nibbles, T>, D::Error>
where
D: Deserializer<'de>,
T: Deserialize<'de>,
{
struct NibblesMapVisitor<T> {
marker: PhantomData<T>,
}

impl<'de, T> Visitor<'de> for NibblesMapVisitor<T>
where
T: Deserialize<'de>,
{
type Value = HashMap<Nibbles, T>;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("a map with hex-encoded Nibbles keys")
}

fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut result = HashMap::with_capacity(map.size_hint().unwrap_or(0));

while let Some((key, value)) = map.next_entry::<String, T>()? {
let decoded_key =
hex::decode(&key).map_err(|err| Error::custom(err.to_string()))?;

let nibbles = Nibbles::unpack(&decoded_key);

result.insert(nibbles, value);
}

Ok(result)
}
}

deserializer.deserialize_map(NibblesMapVisitor { marker: PhantomData })
}
}

/// Sorted trie updates used for lookups and insertions.
Expand Down Expand Up @@ -325,33 +401,51 @@ mod tests {
use super::*;

#[test]
fn test_serialize_trie_updates_works() {
fn test_trie_updates_serde_roundtrip() {
let mut default_updates = TrieUpdates::default();
let _updates_string = serde_json::to_string(&default_updates).unwrap();
let updates_serialized = serde_json::to_string(&default_updates).unwrap();
let updates_deserialized: TrieUpdates = serde_json::from_str(&updates_serialized).unwrap();
assert_eq!(updates_deserialized, default_updates);

default_updates.removed_nodes.insert(Nibbles::from_vec(vec![0x0b, 0x0e, 0x0e, 0x0f]));
let _updates_string = serde_json::to_string(&default_updates).unwrap();
let updates_serialized = serde_json::to_string(&default_updates).unwrap();
let updates_deserialized: TrieUpdates = serde_json::from_str(&updates_serialized).unwrap();
assert_eq!(updates_deserialized, default_updates);

default_updates
.account_nodes
.insert(Nibbles::from_vec(vec![0x0b, 0x0e, 0x0f]), BranchNodeCompact::default());
let _updates_string = serde_json::to_string(&default_updates).unwrap();
.insert(Nibbles::from_vec(vec![0x0d, 0x0e, 0x0a, 0x0d]), BranchNodeCompact::default());
let updates_serialized = serde_json::to_string(&default_updates).unwrap();
println!("{updates_serialized}");
let updates_deserialized: TrieUpdates = serde_json::from_str(&updates_serialized).unwrap();
assert_eq!(updates_deserialized, default_updates);

default_updates.storage_tries.insert(B256::default(), StorageTrieUpdates::default());
let _updates_string = serde_json::to_string(&default_updates).unwrap();
let updates_serialized = serde_json::to_string(&default_updates).unwrap();
let updates_deserialized: TrieUpdates = serde_json::from_str(&updates_serialized).unwrap();
assert_eq!(updates_deserialized, default_updates);
}

#[test]
fn test_serialize_storage_trie_updates_works() {
fn test_storage_trie_updates_serde_roundtrip() {
let mut default_updates = StorageTrieUpdates::default();
let _updates_string = serde_json::to_string(&default_updates).unwrap();
let updates_serialized = serde_json::to_string(&default_updates).unwrap();
let updates_deserialized: StorageTrieUpdates =
serde_json::from_str(&updates_serialized).unwrap();
assert_eq!(updates_deserialized, default_updates);

default_updates.removed_nodes.insert(Nibbles::from_vec(vec![0x0b, 0x0e, 0x0e, 0x0f]));
let _updates_string = serde_json::to_string(&default_updates).unwrap();
let updates_serialized = serde_json::to_string(&default_updates).unwrap();
let updates_deserialized: StorageTrieUpdates =
serde_json::from_str(&updates_serialized).unwrap();
assert_eq!(updates_deserialized, default_updates);

default_updates
.storage_nodes
.insert(Nibbles::from_vec(vec![0x0b, 0x0e, 0x0f]), BranchNodeCompact::default());
let _updates_string = serde_json::to_string(&default_updates).unwrap();
.insert(Nibbles::from_vec(vec![0x0d, 0x0e, 0x0a, 0x0d]), BranchNodeCompact::default());
let updates_serialized = serde_json::to_string(&default_updates).unwrap();
let updates_deserialized: StorageTrieUpdates =
serde_json::from_str(&updates_serialized).unwrap();
assert_eq!(updates_deserialized, default_updates);
}
}

0 comments on commit 86f12b7

Please sign in to comment.