Skip to content

Commit

Permalink
tls_codec: Introduce helper macro for conditional deserialization (#1330
Browse files Browse the repository at this point in the history
)
  • Loading branch information
kkohbrok authored Jan 25, 2024
1 parent 3cace2c commit 14caf10
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 5 deletions.
119 changes: 115 additions & 4 deletions tls_codec/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,32 @@
//! DeserializableExampleStruct::tls_deserialize(&mut serialized.as_slice()).unwrap();
//! # }
//! ```
//!
//! The helper macro `#[tls_codec(cd_field)]` can be used to mark a field as
//! conditionally deserializable, thus allowing nested conditionally
//! deserializable structs.
//!
//! ```
//! # #[cfg(all(feature = "conditional_deserialization", feature = "std"))]
//! # {
//! use tls_codec::{Serialize, Deserialize};
//! use tls_codec_derive::{TlsSerialize, TlsSize, conditionally_deserializable};
//!
//! #[conditionally_deserializable]
//! #[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
//! struct ExampleStruct {
//! a: u8,
//! b: u16,
//! }
//!
//! #[conditionally_deserializable]
//! #[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
//! struct NestedExampleStruct {
//! #[tls_codec(cd_field)]
//! example_struct: ExampleStruct,
//! }
//! # }
//! ```

extern crate proc_macro;
extern crate proc_macro2;
Expand Down Expand Up @@ -324,6 +350,8 @@ enum TlsAttr {
/// This is required to populate the field with a known
/// value during deserialization.
Skip,
#[cfg(feature = "conditional_deserialization")]
CdField,
}

impl TlsAttr {
Expand All @@ -332,6 +360,8 @@ impl TlsAttr {
TlsAttr::With(_) => "with",
TlsAttr::Discriminant(_) => "discriminant",
TlsAttr::Skip => "skip",
#[cfg(feature = "conditional_deserialization")]
TlsAttr::CdField => "cd_field",
}
}

Expand Down Expand Up @@ -392,6 +422,8 @@ impl TlsAttr {
if let Some(ident) = path.get_ident() {
match ident.to_string().to_ascii_lowercase().as_ref() {
"skip" => Ok(TlsAttr::Skip),
#[cfg(feature = "conditional_deserialization")]
"cd_field" => Ok(TlsAttr::CdField),
_ => Err(syn::Error::new_spanned(
ident,
format!("Unexpected identifier {}", ident),
Expand Down Expand Up @@ -454,6 +486,25 @@ fn function_skip(field: &Field) -> Result<bool> {
Ok(skip)
}

/// Process all attributes of a field and return a single, true or false, `cd_field` value.
/// This function will return an error in the case of multiple `cd` attributes.
#[cfg(feature = "conditional_deserialization")]
fn function_cd(field: &Field) -> Result<bool> {
let skip = TlsAttr::parse_multi(&field.attrs)?
.into_iter()
.try_fold(None, |skip, attr| match (skip, attr) {
(None, TlsAttr::CdField) => Ok(Some(true)),
(Some(_), TlsAttr::CdField) => Err(syn::Error::new(
Span::call_site(),
"Attribute `cd_field` specified more than once",
)),
(skip, _) => Ok(skip),
})?
.unwrap_or(false);

Ok(skip)
}

/// Gets the serialization discriminant if specified.
fn discriminant_value(attrs: &[Attribute]) -> Result<Option<DiscriminantValue>> {
TlsAttr::parse_multi(attrs)?
Expand Down Expand Up @@ -986,6 +1037,8 @@ fn restrict_conditional_generic(
) -> (TokenStream2, TokenStream2) {
let impl_generics = quote! { #impl_generics }
.to_string()
// Make string replacement easier by replacing newlines with spaces.
.replace('\n', " ")
.replace(" const IS_DESERIALIZABLE : bool ", "")
.replace("<>", "")
.parse()
Expand Down Expand Up @@ -1276,19 +1329,67 @@ pub fn conditionally_deserializable(
impl_conditionally_deserializable(annotated_item).into()
}

#[cfg(feature = "conditional_deserialization")]
fn set_cd_fields_generic(
mut item_struct: ItemStruct,
value: proc_macro2::TokenStream,
) -> ItemStruct {
use syn::{
parse::{Parse, Parser},
AngleBracketedGenericArguments, PathArguments,
};

item_struct.fields.iter_mut().for_each(|field| {
if function_cd(field).unwrap() {
// We only do this if it's a simple (Path) type
if let Type::Path(path) = &mut field.ty {
// If there is already an AngleBracketedGenericArguments, we just add the const generic at the end.
if let Some(segment) = path.path.segments.last_mut() {
if let PathArguments::AngleBracketed(ref mut argument) = &mut segment.arguments
{
argument.args.push(parse_quote! {#value});
} else {
// If there is no AngleBracketedGenericArguments, we create one and add the const generic.
let parser = AngleBracketedGenericArguments::parse;
let angle_bracketed = parser.parse(TokenStream::from(quote!(<#value>)));
let angle_bracketed = angle_bracketed.unwrap();
segment.arguments = PathArguments::AngleBracketed(angle_bracketed);
}
}
} else {
panic!("Only simple types are supported for conditional deserialization.");
}
}
});
item_struct
}

#[cfg(feature = "conditional_deserialization")]
fn impl_conditionally_deserializable(mut annotated_item: ItemStruct) -> TokenStream2 {
let deserializable_const_generic: ConstParam = parse_quote! {const IS_DESERIALIZABLE: bool};
// Get all the original generics of the annotated item before we modify
// them.
let original_annotated_item = annotated_item.clone();
let (_, original_ty_generics, _) = original_annotated_item.generics.split_for_impl();
// Add the DESERIALIZABLE const generic to the struct
annotated_item
.generics
.params
.push(deserializable_const_generic.into());
// Look through struct fields and if a field has an `cd_field` attribute,
// add the `IS_DESERIALIZABLE` const generic to the field type (set to true
// for the deserialize implementation).
let item_for_deserialize_impl = set_cd_fields_generic(annotated_item.clone(), quote!(true));
// Look through struct fields and if a field has an `cd_field` attribute,
// add the `IS_DESERIALIZABLE` const generic to the field type (set to the
// value IS_DESERIALIZABLE from the struct definition).
let annotated_item = set_cd_fields_generic(annotated_item, quote!(IS_DESERIALIZABLE));

// Derive both TlsDeserialize and TlsDeserializeBytes
let deserialize_bytes_implementation =
impl_deserialize_bytes(parse_ast(annotated_item.clone().into()).unwrap());
impl_deserialize_bytes(parse_ast(item_for_deserialize_impl.clone().into()).unwrap());
let deserialize_implementation =
impl_deserialize(parse_ast(annotated_item.clone().into()).unwrap());
impl_deserialize(parse_ast(item_for_deserialize_impl.clone().into()).unwrap());
let (impl_generics, ty_generics, _) = annotated_item.generics.split_for_impl();
// Patch generics for use by the type aliases
let (_deserializable_impl_generics, deserializable_ty_generics) =
Expand All @@ -1306,11 +1407,21 @@ fn impl_conditionally_deserializable(mut annotated_item: ItemStruct) -> TokenStr
Span::call_site(),
);
let annotated_item_visibility = annotated_item.vis.clone();
let doc_string_deserializable = format!(
"Alias for the deserializable version of the [`{}`].",
annotated_item_ident
);
let doc_string_undeserializable = format!(
"Alias for the version of the [`{}`] that cannot be deserialized.",
annotated_item_ident
);
quote! {
#annotated_item

#annotated_item_visibility type #undeserializable_ident = #annotated_item_ident #undeserializable_ty_generics;
#annotated_item_visibility type #deserializable_ident = #annotated_item_ident #deserializable_ty_generics;
#[doc = #doc_string_deserializable]
#annotated_item_visibility type #undeserializable_ident #original_ty_generics = #annotated_item_ident #undeserializable_ty_generics;
#[doc = #doc_string_undeserializable]
#annotated_item_visibility type #deserializable_ident #original_ty_generics = #annotated_item_ident #deserializable_ty_generics;

#deserialize_implementation

Expand Down
40 changes: 39 additions & 1 deletion tls_codec/derive/tests/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ fn type_with_unknowns() {

#[cfg(feature = "conditional_deserialization")]
mod conditional_deserialization {
use tls_codec::{Deserialize, Serialize};
use tls_codec::{Deserialize, DeserializeBytes, Serialize, Size};
use tls_codec_derive::{conditionally_deserializable, TlsSerialize, TlsSize};

#[test]
Expand All @@ -549,6 +549,13 @@ mod conditional_deserialization {
assert_eq!(deserializable_struct.a, undeserializable_struct.a);
assert_eq!(deserializable_struct.b, undeserializable_struct.b);

#[conditionally_deserializable]
#[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
struct NestedExampleStruct {
#[tls_codec(cd_field)]
nested_field: ExampleStruct,
}

#[conditionally_deserializable]
#[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
struct SecondExampleStruct {
Expand All @@ -565,11 +572,42 @@ mod conditional_deserialization {
a: u8,
b: u16,
}

let undeserializable_struct = UndeserializableExampleStruct { a: 1, b: 2 };
let serialized = undeserializable_struct.tls_serialize_detached().unwrap();
let deserializable_struct =
DeserializableExampleStruct::tls_deserialize_exact(&*serialized).unwrap();
assert_eq!(deserializable_struct.a, undeserializable_struct.a);
assert_eq!(deserializable_struct.b, undeserializable_struct.b);
}

#[test]
fn nested_conditionally_deserializable() {
#[conditionally_deserializable]
#[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
struct ExampleStruct {
a: u8,
b: u16,
}

#[conditionally_deserializable]
#[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
struct NestedExampleStruct {
#[tls_codec(cd_field)]
nested_field: ExampleStruct,
}

#[conditionally_deserializable]
#[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
struct GenericExampleStruct<T: Serialize + Size + Deserialize + DeserializeBytes> {
a: T,
}

#[conditionally_deserializable]
#[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
struct NestedGenericExampleStruct<T: Serialize + Size + Deserialize + DeserializeBytes> {
#[tls_codec(cd_field)]
nested_field: GenericExampleStruct<T>,
}
}
}

0 comments on commit 14caf10

Please sign in to comment.