Skip to content

Commit

Permalink
improve bytes serde
Browse files Browse the repository at this point in the history
  • Loading branch information
polazarus committed Aug 10, 2023
1 parent 8cd1727 commit 74b8bb2
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 98 deletions.
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ rust-version = "1.64.0"

[features]
unstable = []
serde = ["dep:serde", "dep:serde_bytes"]

[[bench]]
name = "competition"
Expand All @@ -25,8 +26,10 @@ arcstr = "1.1.5"
fastrand = "1.9.0"
flexstr = "0.9.2"
imstr = "0.2.0"
serde_test = "1.0.163"
serde_test = "1.0"
serde = { version = "1.0.60", features = ["derive"] }
serde_json = "1.0"

[dependencies]
serde = { version = "1.0.60", optional = true }
serde_bytes = { version = "0.11", optional = true }
129 changes: 32 additions & 97 deletions src/bytes/serde.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::marker::PhantomData;
use std::borrow::Cow;

use serde::de::Visitor;
use serde::{Deserialize, Serialize};

use super::HipByt;
Expand All @@ -18,49 +17,6 @@ where
}
}

#[derive(Clone, Copy, Debug)]
struct BytesVisitor;

impl<'de> Visitor<'de> for BytesVisitor {
type Value = Vec<u8>;

fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v)
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.into())
}

fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.into())
}

fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut v = seq.size_hint().map_or_else(Vec::new, Vec::with_capacity);
while let Some(e) = seq.next_element()? {
v.push(e);
}
Ok(v)
}

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "bytes")
}
}

impl<'de, 'borrow, B> Deserialize<'de> for HipByt<'borrow, B>
where
B: Backend,
Expand All @@ -69,76 +25,39 @@ where
where
D: serde::Deserializer<'de>,
{
let v = deserializer.deserialize_byte_buf(BytesVisitor)?;
let v: Vec<u8> = serde_bytes::deserialize(deserializer)?;
Ok(Self::from(v))
}
}

#[derive(Clone, Copy, Debug)]
struct BorrowingByteVisitor<B>(PhantomData<B>);

impl<'de, B> Visitor<'de> for BorrowingByteVisitor<B>
where
B: Backend,
{
type Value = HipByt<'de, B>;

fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.into())
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.into())
}

fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(HipByt::borrowed(v))
}

fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut v = seq.size_hint().map_or_else(Vec::new, Vec::with_capacity);
while let Some(e) = seq.next_element()? {
v.push(e);
}
Ok(v.into())
}

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "bytes")
}
}

/// Deserializes a `HipByt` as a borrow if possible.
///
/// ```rust
/// # use serde::Deserialize;
/// # use serde_json;
/// use hipstr::bytes::HipByt;
/// use hipstr::Local;
/// #[derive(Deserialize)]
/// struct MyStruct<'a> {
/// #[serde(borrow, deserialize_with = "hipstr::bytes::serde::borrowing_deserialize")]
/// field: HipByt<'a, Local>,
/// }
/// # fn main() {}
/// # fn main() {
/// let s: MyStruct = serde_json::from_str(r#"{"field": "abc"}"#).unwrap();
/// assert!(s.field.is_borrowed());
/// # }
/// ```
pub fn borrowing_deserialize<'de, D, B>(deserializer: D) -> Result<HipByt<'de, B>, D::Error>
///
/// # Errors
///
/// Returns a deserializer if either the serialization is incorrect or an unexpected value is encountered.
pub fn borrowing_deserialize<'de: 'a, 'a, D, B>(deserializer: D) -> Result<HipByt<'a, B>, D::Error>
where
D: serde::Deserializer<'de>,
B: Backend,
{
deserializer.deserialize_byte_buf(BorrowingByteVisitor(PhantomData))
let cow: Cow<'de, [u8]> = serde_bytes::Deserialize::deserialize(deserializer)?;
Ok(HipByt::from(cow))
}

#[cfg(test)]
Expand All @@ -147,6 +66,7 @@ mod tests {
assert_de_tokens, assert_de_tokens_error, assert_ser_tokens, assert_tokens, Token,
};

use crate::bytes::serde::borrowing_deserialize;
use crate::HipByt;

#[test]
Expand Down Expand Up @@ -177,8 +97,23 @@ mod tests {
#[test]
fn test_de_error() {
assert_de_tokens_error::<HipByt>(
&[Token::Str("")],
"invalid type: string \"\", expected bytes",
&[Token::F32(0.0)],
"invalid type: floating point `0`, expected byte array",
);
}

#[test]
fn test_serde_borrowing() {
use serde::de::Deserialize;
use serde_json::Value;

use super::super::HipByt;
use crate::Local;

let v = Value::from("abcdefghijklmnopqrstuvwxyz");
let h1: HipByt<'_, Local> = borrowing_deserialize(&v).unwrap();
let h2: HipByt<'_, Local> = Deserialize::deserialize(&v).unwrap();
assert!(h1.is_borrowed());
assert!(!h2.is_borrowed());
}
}

0 comments on commit 74b8bb2

Please sign in to comment.