diff --git a/Cargo.toml b/Cargo.toml index ea4e7e85..630e7938 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ sanitize = ['crossbeam-epoch/sanitize'] crossbeam-epoch = "0.8.2" parking_lot = "0.10" num_cpus = "1.12.0" +serde = {version = "1.0.105", optional = true} [dependencies.ahash] version = "0.3.2" @@ -35,6 +36,7 @@ default-features = false rand = "0.7" rayon = "1.3" criterion = "0.3" +serde_json = "1.0.50" [[bench]] name = "flurry_dashmap" diff --git a/src/lib.rs b/src/lib.rs index 7340f053..5060ba24 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -223,6 +223,9 @@ mod raw; mod set; mod set_ref; +#[cfg(feature = "serde")] +mod serde_impls; + /// Iterator types. pub mod iter; diff --git a/src/serde_impls.rs b/src/serde_impls.rs new file mode 100644 index 00000000..42e5ebbd --- /dev/null +++ b/src/serde_impls.rs @@ -0,0 +1,216 @@ +use crate::{HashMap, HashMapRef, HashSet, HashSetRef}; +use serde::{ + de::{MapAccess, SeqAccess, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; +use std::fmt::{self, Formatter}; +use std::hash::{BuildHasher, Hash}; +use std::marker::PhantomData; + +struct HashMapVisitor { + key_marker: PhantomData, + value_marker: PhantomData, + hash_builder_marker: PhantomData, +} + +impl Serialize for HashMapRef<'_, K, V, S> +where + K: Serialize, + V: Serialize, +{ + fn serialize(&self, serializer: Sr) -> Result + where + Sr: Serializer, + { + serializer.collect_map(self.iter()) + } +} + +impl Serialize for HashMap +where + K: Serialize, + V: Serialize, +{ + fn serialize(&self, serializer: Sr) -> Result + where + Sr: Serializer, + { + self.pin().serialize(serializer) + } +} + +impl<'de, K, V, S> Deserialize<'de> for HashMap +where + K: 'static + Deserialize<'de> + Send + Sync + Hash + Clone + Eq, + V: 'static + Deserialize<'de> + Send + Sync + Eq, + S: Default + BuildHasher, +{ + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_map(HashMapVisitor::new()) + } +} + +impl HashMapVisitor { + pub(crate) fn new() -> Self { + Self { + key_marker: PhantomData, + value_marker: PhantomData, + hash_builder_marker: PhantomData, + } + } +} + +impl<'de, K, V, S> Visitor<'de> for HashMapVisitor +where + K: 'static + Deserialize<'de> + Send + Sync + Hash + Clone + Eq, + V: 'static + Deserialize<'de> + Send + Sync + Eq, + S: Default + BuildHasher, +{ + type Value = HashMap; + + fn expecting(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "a map") + } + + fn visit_map(self, mut access: M) -> Result + where + M: MapAccess<'de>, + { + let map = match access.size_hint() { + Some(n) => HashMap::with_capacity_and_hasher(n, S::default()), + None => HashMap::with_hasher(S::default()), + }; + let guard = map.guard(); + + while let Some((key, value)) = access.next_entry()? { + if let Some(_old_value) = map.insert(key, value, &guard) { + unreachable!("Serialized map held two values with the same key"); + } + } + + Ok(map) + } +} + +impl Serialize for HashSetRef<'_, T, S> +where + T: Serialize, +{ + fn serialize(&self, serilizer: Sr) -> Result + where + Sr: Serializer, + { + serilizer.collect_seq(self.iter()) + } +} + +impl Serialize for HashSet +where + T: Serialize, +{ + fn serialize(&self, serializer: Sr) -> Result + where + Sr: Serializer, + { + self.pin().serialize(serializer) + } +} + +impl<'de, T, S> Deserialize<'de> for HashSet +where + T: 'static + Deserialize<'de> + Send + Sync + Hash + Clone + Eq, + S: Default + BuildHasher, +{ + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_seq(HashSetVisitor::new()) + } +} + +struct HashSetVisitor { + type_marker: PhantomData, + hash_builder_marker: PhantomData, +} + +impl HashSetVisitor { + pub(crate) fn new() -> Self { + Self { + type_marker: PhantomData, + hash_builder_marker: PhantomData, + } + } +} + +impl<'de, T, S> Visitor<'de> for HashSetVisitor +where + T: 'static + Deserialize<'de> + Send + Sync + Hash + Clone + Eq, + S: Default + BuildHasher, +{ + type Value = HashSet; + + fn expecting(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "a set") + } + + fn visit_seq(self, mut access: A) -> Result + where + A: SeqAccess<'de>, + { + let set = HashSet::default(); + let guard = set.guard(); + + while let Some(value) = access.next_element()? { + let _ = set.insert(value, &guard); + } + + Ok(set) + } +} + +#[cfg(test)] +mod test { + use crate::{HashMap, HashSet}; + + #[test] + fn test_map() { + let map: HashMap = HashMap::with_capacity(5); + let guard = map.guard(); + + let _ = map.insert(0, 4, &guard); + let _ = map.insert(1, 3, &guard); + let _ = map.insert(2, 2, &guard); + let _ = map.insert(3, 1, &guard); + let _ = map.insert(4, 0, &guard); + + let serialized = serde_json::to_string(&map).expect("Couldn't serialize map"); + + let deserialized: HashMap = + serde_json::from_str(&serialized).expect("Couldn't deserialize map"); + + assert_eq!(map, deserialized); + } + + #[test] + fn test_set() { + let set: HashSet = HashSet::with_capacity(5); + let guard = set.guard(); + + let _ = set.insert(0, &guard); + let _ = set.insert(1, &guard); + let _ = set.insert(2, &guard); + let _ = set.insert(3, &guard); + let _ = set.insert(4, &guard); + + let serialized = serde_json::to_string(&set).expect("Couldn't serialize map"); + + let deserialized: HashSet = + serde_json::from_str(&serialized).expect("Couldn't deserialize map"); + + assert_eq!(set, deserialized); + } +} diff --git a/src/set.rs b/src/set.rs index c0cb46cd..cdd2c690 100644 --- a/src/set.rs +++ b/src/set.rs @@ -2,15 +2,14 @@ //! //! See `HashSet` for details. +use crate::epoch::Guard; +use crate::iter::Keys; +use crate::HashMap; use std::borrow::Borrow; use std::fmt::{self, Debug, Formatter}; use std::hash::{BuildHasher, Hash}; use std::iter::FromIterator; -use crate::epoch::Guard; -use crate::iter::Keys; -use crate::HashMap; - /// A concurrent hash set implemented as a `HashMap` where the value is `()`. /// /// # Examples