From 14caf10145ed82b2191eedcb5b7391563ad25470 Mon Sep 17 00:00:00 2001 From: Konrad Kohbrok Date: Thu, 25 Jan 2024 15:33:21 +0100 Subject: [PATCH] tls_codec: Introduce helper macro for conditional deserialization (#1330) --- tls_codec/derive/src/lib.rs | 119 +++++++++++++++++++++++++++++-- tls_codec/derive/tests/decode.rs | 40 ++++++++++- 2 files changed, 154 insertions(+), 5 deletions(-) diff --git a/tls_codec/derive/src/lib.rs b/tls_codec/derive/src/lib.rs index d9653366b..29ad53388 100644 --- a/tls_codec/derive/src/lib.rs +++ b/tls_codec/derive/src/lib.rs @@ -234,6 +234,32 @@ //! DeserializableExampleStruct::tls_deserialize(&mut serialized.as_slice()).unwrap(); //! # } //! ``` +//! +//! The helper macro `#[tls_codec(cd_field)]` can be used to mark a field as +//! conditionally deserializable, thus allowing nested conditionally +//! deserializable structs. +//! +//! ``` +//! # #[cfg(all(feature = "conditional_deserialization", feature = "std"))] +//! # { +//! use tls_codec::{Serialize, Deserialize}; +//! use tls_codec_derive::{TlsSerialize, TlsSize, conditionally_deserializable}; +//! +//! #[conditionally_deserializable] +//! #[derive(TlsSize, TlsSerialize, PartialEq, Debug)] +//! struct ExampleStruct { +//! a: u8, +//! b: u16, +//! } +//! +//! #[conditionally_deserializable] +//! #[derive(TlsSize, TlsSerialize, PartialEq, Debug)] +//! struct NestedExampleStruct { +//! #[tls_codec(cd_field)] +//! example_struct: ExampleStruct, +//! } +//! # } +//! ``` extern crate proc_macro; extern crate proc_macro2; @@ -324,6 +350,8 @@ enum TlsAttr { /// This is required to populate the field with a known /// value during deserialization. Skip, + #[cfg(feature = "conditional_deserialization")] + CdField, } impl TlsAttr { @@ -332,6 +360,8 @@ impl TlsAttr { TlsAttr::With(_) => "with", TlsAttr::Discriminant(_) => "discriminant", TlsAttr::Skip => "skip", + #[cfg(feature = "conditional_deserialization")] + TlsAttr::CdField => "cd_field", } } @@ -392,6 +422,8 @@ impl TlsAttr { if let Some(ident) = path.get_ident() { match ident.to_string().to_ascii_lowercase().as_ref() { "skip" => Ok(TlsAttr::Skip), + #[cfg(feature = "conditional_deserialization")] + "cd_field" => Ok(TlsAttr::CdField), _ => Err(syn::Error::new_spanned( ident, format!("Unexpected identifier {}", ident), @@ -454,6 +486,25 @@ fn function_skip(field: &Field) -> Result { Ok(skip) } +/// Process all attributes of a field and return a single, true or false, `cd_field` value. +/// This function will return an error in the case of multiple `cd` attributes. +#[cfg(feature = "conditional_deserialization")] +fn function_cd(field: &Field) -> Result { + let skip = TlsAttr::parse_multi(&field.attrs)? + .into_iter() + .try_fold(None, |skip, attr| match (skip, attr) { + (None, TlsAttr::CdField) => Ok(Some(true)), + (Some(_), TlsAttr::CdField) => Err(syn::Error::new( + Span::call_site(), + "Attribute `cd_field` specified more than once", + )), + (skip, _) => Ok(skip), + })? + .unwrap_or(false); + + Ok(skip) +} + /// Gets the serialization discriminant if specified. fn discriminant_value(attrs: &[Attribute]) -> Result> { TlsAttr::parse_multi(attrs)? @@ -986,6 +1037,8 @@ fn restrict_conditional_generic( ) -> (TokenStream2, TokenStream2) { let impl_generics = quote! { #impl_generics } .to_string() + // Make string replacement easier by replacing newlines with spaces. + .replace('\n', " ") .replace(" const IS_DESERIALIZABLE : bool ", "") .replace("<>", "") .parse() @@ -1276,19 +1329,67 @@ pub fn conditionally_deserializable( impl_conditionally_deserializable(annotated_item).into() } +#[cfg(feature = "conditional_deserialization")] +fn set_cd_fields_generic( + mut item_struct: ItemStruct, + value: proc_macro2::TokenStream, +) -> ItemStruct { + use syn::{ + parse::{Parse, Parser}, + AngleBracketedGenericArguments, PathArguments, + }; + + item_struct.fields.iter_mut().for_each(|field| { + if function_cd(field).unwrap() { + // We only do this if it's a simple (Path) type + if let Type::Path(path) = &mut field.ty { + // If there is already an AngleBracketedGenericArguments, we just add the const generic at the end. + if let Some(segment) = path.path.segments.last_mut() { + if let PathArguments::AngleBracketed(ref mut argument) = &mut segment.arguments + { + argument.args.push(parse_quote! {#value}); + } else { + // If there is no AngleBracketedGenericArguments, we create one and add the const generic. + let parser = AngleBracketedGenericArguments::parse; + let angle_bracketed = parser.parse(TokenStream::from(quote!(<#value>))); + let angle_bracketed = angle_bracketed.unwrap(); + segment.arguments = PathArguments::AngleBracketed(angle_bracketed); + } + } + } else { + panic!("Only simple types are supported for conditional deserialization."); + } + } + }); + item_struct +} + #[cfg(feature = "conditional_deserialization")] fn impl_conditionally_deserializable(mut annotated_item: ItemStruct) -> TokenStream2 { let deserializable_const_generic: ConstParam = parse_quote! {const IS_DESERIALIZABLE: bool}; + // Get all the original generics of the annotated item before we modify + // them. + let original_annotated_item = annotated_item.clone(); + let (_, original_ty_generics, _) = original_annotated_item.generics.split_for_impl(); // Add the DESERIALIZABLE const generic to the struct annotated_item .generics .params .push(deserializable_const_generic.into()); + // Look through struct fields and if a field has an `cd_field` attribute, + // add the `IS_DESERIALIZABLE` const generic to the field type (set to true + // for the deserialize implementation). + let item_for_deserialize_impl = set_cd_fields_generic(annotated_item.clone(), quote!(true)); + // Look through struct fields and if a field has an `cd_field` attribute, + // add the `IS_DESERIALIZABLE` const generic to the field type (set to the + // value IS_DESERIALIZABLE from the struct definition). + let annotated_item = set_cd_fields_generic(annotated_item, quote!(IS_DESERIALIZABLE)); + // Derive both TlsDeserialize and TlsDeserializeBytes let deserialize_bytes_implementation = - impl_deserialize_bytes(parse_ast(annotated_item.clone().into()).unwrap()); + impl_deserialize_bytes(parse_ast(item_for_deserialize_impl.clone().into()).unwrap()); let deserialize_implementation = - impl_deserialize(parse_ast(annotated_item.clone().into()).unwrap()); + impl_deserialize(parse_ast(item_for_deserialize_impl.clone().into()).unwrap()); let (impl_generics, ty_generics, _) = annotated_item.generics.split_for_impl(); // Patch generics for use by the type aliases let (_deserializable_impl_generics, deserializable_ty_generics) = @@ -1306,11 +1407,21 @@ fn impl_conditionally_deserializable(mut annotated_item: ItemStruct) -> TokenStr Span::call_site(), ); let annotated_item_visibility = annotated_item.vis.clone(); + let doc_string_deserializable = format!( + "Alias for the deserializable version of the [`{}`].", + annotated_item_ident + ); + let doc_string_undeserializable = format!( + "Alias for the version of the [`{}`] that cannot be deserialized.", + annotated_item_ident + ); quote! { #annotated_item - #annotated_item_visibility type #undeserializable_ident = #annotated_item_ident #undeserializable_ty_generics; - #annotated_item_visibility type #deserializable_ident = #annotated_item_ident #deserializable_ty_generics; + #[doc = #doc_string_deserializable] + #annotated_item_visibility type #undeserializable_ident #original_ty_generics = #annotated_item_ident #undeserializable_ty_generics; + #[doc = #doc_string_undeserializable] + #annotated_item_visibility type #deserializable_ident #original_ty_generics = #annotated_item_ident #deserializable_ty_generics; #deserialize_implementation diff --git a/tls_codec/derive/tests/decode.rs b/tls_codec/derive/tests/decode.rs index feb487341..c594eae49 100644 --- a/tls_codec/derive/tests/decode.rs +++ b/tls_codec/derive/tests/decode.rs @@ -530,7 +530,7 @@ fn type_with_unknowns() { #[cfg(feature = "conditional_deserialization")] mod conditional_deserialization { - use tls_codec::{Deserialize, Serialize}; + use tls_codec::{Deserialize, DeserializeBytes, Serialize, Size}; use tls_codec_derive::{conditionally_deserializable, TlsSerialize, TlsSize}; #[test] @@ -549,6 +549,13 @@ mod conditional_deserialization { assert_eq!(deserializable_struct.a, undeserializable_struct.a); assert_eq!(deserializable_struct.b, undeserializable_struct.b); + #[conditionally_deserializable] + #[derive(TlsSize, TlsSerialize, PartialEq, Debug)] + struct NestedExampleStruct { + #[tls_codec(cd_field)] + nested_field: ExampleStruct, + } + #[conditionally_deserializable] #[derive(TlsSize, TlsSerialize, PartialEq, Debug)] struct SecondExampleStruct { @@ -565,6 +572,7 @@ mod conditional_deserialization { a: u8, b: u16, } + let undeserializable_struct = UndeserializableExampleStruct { a: 1, b: 2 }; let serialized = undeserializable_struct.tls_serialize_detached().unwrap(); let deserializable_struct = @@ -572,4 +580,34 @@ mod conditional_deserialization { assert_eq!(deserializable_struct.a, undeserializable_struct.a); assert_eq!(deserializable_struct.b, undeserializable_struct.b); } + + #[test] + fn nested_conditionally_deserializable() { + #[conditionally_deserializable] + #[derive(TlsSize, TlsSerialize, PartialEq, Debug)] + struct ExampleStruct { + a: u8, + b: u16, + } + + #[conditionally_deserializable] + #[derive(TlsSize, TlsSerialize, PartialEq, Debug)] + struct NestedExampleStruct { + #[tls_codec(cd_field)] + nested_field: ExampleStruct, + } + + #[conditionally_deserializable] + #[derive(TlsSize, TlsSerialize, PartialEq, Debug)] + struct GenericExampleStruct { + a: T, + } + + #[conditionally_deserializable] + #[derive(TlsSize, TlsSerialize, PartialEq, Debug)] + struct NestedGenericExampleStruct { + #[tls_codec(cd_field)] + nested_field: GenericExampleStruct, + } + } }