Skip to content

Commit

Permalink
Correctly process flatten fields in enum variants
Browse files Browse the repository at this point in the history
- Fix incorrect deserialization of variants that doesn't contain flatten field when other contains
- Fix a panic when deriving `Deserialize` for an enum with tuple and struct with flatten field

Fixes (2):
    regression::issue2565::simple_variant
    regression::issue1904 (compilation)
  • Loading branch information
Mingun committed Jul 23, 2024
1 parent 6ab20bb commit 2b88ec0
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 31 deletions.
93 changes: 62 additions & 31 deletions serde_derive/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,21 @@ fn deserialize_body(cont: &Container, params: &Parameters) -> Fragment {
} else if let attr::Identifier::No = cont.attrs.identifier() {
match &cont.data {
Data::Enum(variants) => deserialize_enum(params, variants, &cont.attrs),
Data::Struct(Style::Struct, fields) => {
deserialize_struct(params, fields, &cont.attrs, StructForm::Struct)
}
Data::Struct(Style::Struct, fields) => deserialize_struct(
params,
fields,
&cont.attrs,
cont.attrs.has_flatten(),
StructForm::Struct,
),
Data::Struct(Style::Tuple, fields) | Data::Struct(Style::Newtype, fields) => {
deserialize_tuple(params, fields, &cont.attrs, TupleForm::Tuple)
deserialize_tuple(
params,
fields,
&cont.attrs,
cont.attrs.has_flatten(),
TupleForm::Tuple,
)
}
Data::Struct(Style::Unit, _) => deserialize_unit_struct(params, &cont.attrs),
}
Expand Down Expand Up @@ -459,9 +469,13 @@ fn deserialize_tuple(
params: &Parameters,
fields: &[Field],
cattrs: &attr::Container,
has_flatten: bool,
form: TupleForm,
) -> Fragment {
assert!(!cattrs.has_flatten());
assert!(
!has_flatten,
"tuples and tuple variants cannot have flatten fields"
);

let field_count = fields
.iter()
Expand Down Expand Up @@ -579,7 +593,10 @@ fn deserialize_tuple_in_place(
fields: &[Field],
cattrs: &attr::Container,
) -> Fragment {
assert!(!cattrs.has_flatten());
assert!(
!cattrs.has_flatten(),
"tuples and tuple variants cannot have flatten fields"
);

let field_count = fields
.iter()
Expand Down Expand Up @@ -910,6 +927,7 @@ fn deserialize_struct(
params: &Parameters,
fields: &[Field],
cattrs: &attr::Container,
has_flatten: bool,
form: StructForm,
) -> Fragment {
let this_type = &params.this_type;
Expand Down Expand Up @@ -958,13 +976,13 @@ fn deserialize_struct(
)
})
.collect();
let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs);
let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs, has_flatten);

// untagged struct variants do not get a visit_seq method. The same applies to
// structs that only have a map representation.
let visit_seq = match form {
StructForm::Untagged(..) => None,
_ if cattrs.has_flatten() => None,
_ if has_flatten => None,
_ => {
let mut_seq = if field_names_idents.is_empty() {
quote!(_)
Expand All @@ -987,10 +1005,16 @@ fn deserialize_struct(
})
}
};
let visit_map = Stmts(deserialize_map(&type_path, params, fields, cattrs));
let visit_map = Stmts(deserialize_map(
&type_path,
params,
fields,
cattrs,
has_flatten,
));

let visitor_seed = match form {
StructForm::ExternallyTagged(..) if cattrs.has_flatten() => Some(quote! {
StructForm::ExternallyTagged(..) if has_flatten => Some(quote! {
impl #de_impl_generics _serde::de::DeserializeSeed<#delife> for __Visitor #de_ty_generics #where_clause {
type Value = #this_type #ty_generics;

Expand All @@ -1005,7 +1029,7 @@ fn deserialize_struct(
_ => None,
};

let fields_stmt = if cattrs.has_flatten() {
let fields_stmt = if has_flatten {
None
} else {
let field_names = field_names_idents
Expand All @@ -1025,7 +1049,7 @@ fn deserialize_struct(
}
};
let dispatch = match form {
StructForm::Struct if cattrs.has_flatten() => quote! {
StructForm::Struct if has_flatten => quote! {
_serde::Deserializer::deserialize_map(__deserializer, #visitor_expr)
},
StructForm::Struct => {
Expand All @@ -1034,7 +1058,7 @@ fn deserialize_struct(
_serde::Deserializer::deserialize_struct(__deserializer, #type_name, FIELDS, #visitor_expr)
}
}
StructForm::ExternallyTagged(_) if cattrs.has_flatten() => quote! {
StructForm::ExternallyTagged(_) if has_flatten => quote! {
_serde::de::VariantAccess::newtype_variant_seed(__variant, #visitor_expr)
},
StructForm::ExternallyTagged(_) => quote! {
Expand Down Expand Up @@ -1116,7 +1140,7 @@ fn deserialize_struct_in_place(
})
.collect();

let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs);
let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs, false);

let mut_seq = if field_names_idents.is_empty() {
quote!(_)
Expand Down Expand Up @@ -1210,10 +1234,7 @@ fn deserialize_homogeneous_enum(
}
}

fn prepare_enum_variant_enum(
variants: &[Variant],
cattrs: &attr::Container,
) -> (TokenStream, Stmts) {
fn prepare_enum_variant_enum(variants: &[Variant]) -> (TokenStream, Stmts) {
let mut deserialized_variants = variants
.iter()
.enumerate()
Expand Down Expand Up @@ -1247,7 +1268,7 @@ fn prepare_enum_variant_enum(

let variant_visitor = Stmts(deserialize_generated_identifier(
&variant_names_idents,
cattrs,
false, // variant identifiers does not depend on the presence of flatten fields
true,
None,
fallthrough,
Expand All @@ -1270,7 +1291,7 @@ fn deserialize_externally_tagged_enum(
let expecting = format!("enum {}", params.type_name());
let expecting = cattrs.expecting().unwrap_or(&expecting);

let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs);
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants);

// Match arms to extract a variant from a string
let variant_arms = variants
Expand Down Expand Up @@ -1355,7 +1376,7 @@ fn deserialize_internally_tagged_enum(
cattrs: &attr::Container,
tag: &str,
) -> Fragment {
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs);
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants);

// Match arms to extract a variant from a string
let variant_arms = variants
Expand Down Expand Up @@ -1409,7 +1430,7 @@ fn deserialize_adjacently_tagged_enum(
split_with_de_lifetime(params);
let delife = params.borrowed.de_lifetime();

let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs);
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants);

let variant_arms: &Vec<_> = &variants
.iter()
Expand Down Expand Up @@ -1810,12 +1831,14 @@ fn deserialize_externally_tagged_variant(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
TupleForm::ExternallyTagged(variant_ident),
),
Style::Struct => deserialize_struct(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
StructForm::ExternallyTagged(variant_ident),
),
}
Expand Down Expand Up @@ -1859,6 +1882,7 @@ fn deserialize_internally_tagged_variant(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
StructForm::InternallyTagged(variant_ident, deserializer),
),
Style::Tuple => unreachable!("checked in serde_derive_internals"),
Expand Down Expand Up @@ -1909,12 +1933,14 @@ fn deserialize_untagged_variant(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
TupleForm::Untagged(variant_ident, deserializer),
),
Style::Struct => deserialize_struct(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
StructForm::Untagged(variant_ident, deserializer),
),
}
Expand Down Expand Up @@ -1985,7 +2011,7 @@ fn deserialize_untagged_newtype_variant(

fn deserialize_generated_identifier(
fields: &[(&str, Ident, &BTreeSet<String>)],
cattrs: &attr::Container,
has_flatten: bool,
is_variant: bool,
ignore_variant: Option<TokenStream>,
fallthrough: Option<TokenStream>,
Expand All @@ -1999,11 +2025,11 @@ fn deserialize_generated_identifier(
is_variant,
fallthrough,
None,
!is_variant && cattrs.has_flatten(),
!is_variant && has_flatten,
None,
));

let lifetime = if !is_variant && cattrs.has_flatten() {
let lifetime = if !is_variant && has_flatten {
Some(quote!(<'de>))
} else {
None
Expand Down Expand Up @@ -2043,8 +2069,9 @@ fn deserialize_generated_identifier(
fn deserialize_field_identifier(
fields: &[(&str, Ident, &BTreeSet<String>)],
cattrs: &attr::Container,
has_flatten: bool,
) -> Stmts {
let (ignore_variant, fallthrough) = if cattrs.has_flatten() {
let (ignore_variant, fallthrough) = if has_flatten {
let ignore_variant = quote!(__other(_serde::__private::de::Content<'de>),);
let fallthrough = quote!(_serde::__private::Ok(__Field::__other(__value)));
(Some(ignore_variant), Some(fallthrough))
Expand All @@ -2058,7 +2085,7 @@ fn deserialize_field_identifier(

Stmts(deserialize_generated_identifier(
fields,
cattrs,
has_flatten,
false,
ignore_variant,
fallthrough,
Expand Down Expand Up @@ -2460,6 +2487,7 @@ fn deserialize_map(
params: &Parameters,
fields: &[Field],
cattrs: &attr::Container,
has_flatten: bool,
) -> Fragment {
// Create the field names for the fields.
let fields_names: Vec<_> = fields
Expand All @@ -2480,7 +2508,7 @@ fn deserialize_map(
});

// Collect contents for flatten fields into a buffer
let let_collect = if cattrs.has_flatten() {
let let_collect = if has_flatten {
Some(quote! {
let mut __collect = _serde::__private::Vec::<_serde::__private::Option<(
_serde::__private::de::Content,
Expand Down Expand Up @@ -2532,7 +2560,7 @@ fn deserialize_map(
});

// Visit ignored values to consume them
let ignored_arm = if cattrs.has_flatten() {
let ignored_arm = if has_flatten {
Some(quote! {
__Field::__other(__name) => {
__collect.push(_serde::__private::Some((
Expand Down Expand Up @@ -2602,7 +2630,7 @@ fn deserialize_map(
}
});

let collected_deny_unknown_fields = if cattrs.has_flatten() && cattrs.deny_unknown_fields() {
let collected_deny_unknown_fields = if has_flatten && cattrs.deny_unknown_fields() {
Some(quote! {
if let _serde::__private::Some(_serde::__private::Some((__key, _))) =
__collect.into_iter().filter(_serde::__private::Option::is_some).next()
Expand Down Expand Up @@ -2678,7 +2706,10 @@ fn deserialize_map_in_place(
fields: &[Field],
cattrs: &attr::Container,
) -> Fragment {
assert!(!cattrs.has_flatten());
assert!(
!cattrs.has_flatten(),
"inplace deserialization of maps doesn't support flatten fields"
);

// Create the field names for the fields.
let fields_names: Vec<_> = fields
Expand Down
1 change: 1 addition & 0 deletions serde_derive/src/internals/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ impl<'a> Container<'a> {
for field in &mut variant.fields {
if field.attrs.flatten() {
has_flatten = true;
variant.attrs.mark_has_flatten();
}
field.attrs.rename_by_rules(
variant
Expand Down
37 changes: 37 additions & 0 deletions serde_derive/src/internals/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,22 @@ pub struct Container {
type_into: Option<syn::Type>,
remote: Option<syn::Path>,
identifier: Identifier,
/// `true` if container is a `struct` and it has a field with `#[serde(flatten)]`
/// attribute or it is an `enum` with a struct variant which has a field with
/// `#[serde(flatten)]` attribute. Examples:
///
/// ```ignore
/// struct Container {
/// #[serde(flatten)]
/// some_field: (),
/// }
/// enum Container {
/// Variant {
/// #[serde(flatten)]
/// some_field: (),
/// },
/// }
/// ```
has_flatten: bool,
serde_path: Option<syn::Path>,
is_packed: bool,
Expand Down Expand Up @@ -794,6 +810,18 @@ pub struct Variant {
rename_all_rules: RenameAllRules,
ser_bound: Option<Vec<syn::WherePredicate>>,
de_bound: Option<Vec<syn::WherePredicate>>,
/// `true` if variant is a struct variant which contains a field with `#[serde(flatten)]`
/// attribute. Examples:
///
/// ```ignore
/// enum Enum {
/// Variant {
/// #[serde(flatten)]
/// some_field: (),
/// },
/// }
/// ```
has_flatten: bool,
skip_deserializing: bool,
skip_serializing: bool,
other: bool,
Expand Down Expand Up @@ -963,6 +991,7 @@ impl Variant {
},
ser_bound: ser_bound.get(),
de_bound: de_bound.get(),
has_flatten: false,
skip_deserializing: skip_deserializing.get(),
skip_serializing: skip_serializing.get(),
other: other.get(),
Expand Down Expand Up @@ -1005,6 +1034,14 @@ impl Variant {
self.de_bound.as_ref().map(|vec| &vec[..])
}

pub fn has_flatten(&self) -> bool {
self.has_flatten
}

pub fn mark_has_flatten(&mut self) {
self.has_flatten = true;
}

pub fn skip_deserializing(&self) -> bool {
self.skip_deserializing
}
Expand Down

0 comments on commit 2b88ec0

Please sign in to comment.