diff --git a/servo/components/style/values/animated/mod.rs b/servo/components/style/values/animated/mod.rs index 7cac407adcba5..7e699542fd402 100644 --- a/servo/components/style/values/animated/mod.rs +++ b/servo/components/style/values/animated/mod.rs @@ -107,7 +107,7 @@ pub fn animate_multiplicative_factor( /// be equal or an error is returned. /// /// If a variant is annotated with `#[animation(error)]`, the corresponding -/// `match` arm is not generated. +/// `match` arm returns an error. /// /// If the two values are not similar, an error is returned unless a fallback /// function has been specified through `#[animate(fallback)]`. diff --git a/servo/components/style/values/distance.rs b/servo/components/style/values/distance.rs index a1872366c2a45..67c735676b5bf 100644 --- a/servo/components/style/values/distance.rs +++ b/servo/components/style/values/distance.rs @@ -17,7 +17,7 @@ use std::ops::Add; /// on each fields of the values. /// /// If a variant is annotated with `#[animation(error)]`, the corresponding -/// `match` arm is not generated. +/// `match` arm returns an error. /// /// If the two values are not similar, an error is returned unless a fallback /// function has been specified through `#[distance(fallback)]`. diff --git a/servo/components/style_derive/animate.rs b/servo/components/style_derive/animate.rs index 4d8581a8f2bdf..190568514fc63 100644 --- a/servo/components/style_derive/animate.rs +++ b/servo/components/style_derive/animate.rs @@ -11,6 +11,8 @@ use synstructure::{Structure, VariantInfo}; pub fn derive(mut input: DeriveInput) -> TokenStream { let animation_input_attrs = cg::parse_input_attrs::(&input); + let input_attrs = cg::parse_input_attrs::(&input); + let no_bound = animation_input_attrs.no_bound.unwrap_or_default(); let mut where_clause = input.generics.where_clause.take(); for param in input.generics.type_params() { @@ -21,39 +23,32 @@ pub fn derive(mut input: DeriveInput) -> TokenStream { ); } } - let (mut match_body, append_error_clause) = { + let (mut match_body, needs_catchall_branch) = { let s = Structure::new(&input); - let mut append_error_clause = s.variants().len() > 1; - + let needs_catchall_branch = s.variants().len() > 1; let match_body = s.variants().iter().fold(quote!(), |body, variant| { - let arm = match derive_variant_arm(variant, &mut where_clause) { - Ok(arm) => arm, - Err(()) => { - append_error_clause = true; - return body; - }, - }; + let arm = derive_variant_arm(variant, &mut where_clause); quote! { #body #arm } }); - (match_body, append_error_clause) + (match_body, needs_catchall_branch) }; input.generics.where_clause = where_clause; - if append_error_clause { - let input_attrs = cg::parse_input_attrs::(&input); - if let Some(fallback) = input_attrs.fallback { - match_body.append_all(quote! { - (this, other) => #fallback(this, other, procedure) - }); - } else { - match_body.append_all(quote! { _ => Err(()) }); - } + if needs_catchall_branch { + // This ideally shouldn't be needed, but see + // https://github.com/rust-lang/rust/issues/68867 + match_body.append_all(quote! { _ => unsafe { debug_unreachable!() } }); } let name = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let fallback = match input_attrs.fallback { + Some(fallback) => quote! { #fallback(self, other, procedure) }, + None => quote! { Err(()) }, + }; + quote! { impl #impl_generics crate::values::animated::Animate for #name #ty_generics #where_clause { #[allow(unused_variables, unused_imports)] @@ -63,6 +58,9 @@ pub fn derive(mut input: DeriveInput) -> TokenStream { other: &Self, procedure: crate::values::animated::Procedure, ) -> Result { + if std::mem::discriminant(self) != std::mem::discriminant(other) { + return #fallback; + } match (self, other) { #match_body } @@ -74,13 +72,17 @@ pub fn derive(mut input: DeriveInput) -> TokenStream { fn derive_variant_arm( variant: &VariantInfo, where_clause: &mut Option, -) -> Result { +) -> TokenStream { let variant_attrs = cg::parse_variant_attrs_from_ast::(&variant.ast()); - if variant_attrs.error { - return Err(()); - } let (this_pattern, this_info) = cg::ref_pattern(&variant, "this"); let (other_pattern, other_info) = cg::ref_pattern(&variant, "other"); + + if variant_attrs.error { + return quote! { + (&#this_pattern, &#other_pattern) => Err(()), + }; + } + let (result_value, result_info) = cg::value(&variant, "result"); let mut computations = quote!(); let iter = result_info.iter().zip(this_info.iter().zip(&other_info)); @@ -107,12 +109,13 @@ fn derive_variant_arm( } } })); - Ok(quote! { + + quote! { (&#this_pattern, &#other_pattern) => { #computations Ok(#result_value) } - }) + } } #[darling(attributes(animate), default)] diff --git a/servo/components/style_derive/compute_squared_distance.rs b/servo/components/style_derive/compute_squared_distance.rs index 9c5f7ec80d1b9..5e130e75b065c 100644 --- a/servo/components/style_derive/compute_squared_distance.rs +++ b/servo/components/style_derive/compute_squared_distance.rs @@ -6,11 +6,12 @@ use crate::animate::{AnimationFieldAttrs, AnimationInputAttrs, AnimationVariantA use derive_common::cg; use proc_macro2::TokenStream; use quote::TokenStreamExt; -use syn::{DeriveInput, Path}; +use syn::{DeriveInput, Path, WhereClause}; use synstructure; pub fn derive(mut input: DeriveInput) -> TokenStream { let animation_input_attrs = cg::parse_input_attrs::(&input); + let input_attrs = cg::parse_input_attrs::(&input); let no_bound = animation_input_attrs.no_bound.unwrap_or_default(); let mut where_clause = input.generics.where_clause.take(); for param in input.generics.type_params() { @@ -22,76 +23,31 @@ pub fn derive(mut input: DeriveInput) -> TokenStream { } } - let (mut match_body, append_error_clause) = { + let (mut match_body, needs_catchall_branch) = { let s = synstructure::Structure::new(&input); - let mut append_error_clause = s.variants().len() > 1; + let needs_catchall_branch = s.variants().len() > 1; let match_body = s.variants().iter().fold(quote!(), |body, variant| { - let attrs = cg::parse_variant_attrs_from_ast::(&variant.ast()); - if attrs.error { - append_error_clause = true; - return body; - } - - let (this_pattern, this_info) = cg::ref_pattern(&variant, "this"); - let (other_pattern, other_info) = cg::ref_pattern(&variant, "other"); - let sum = if this_info.is_empty() { - quote! { crate::values::distance::SquaredDistance::from_sqrt(0.) } - } else { - let mut sum = quote!(); - sum.append_separated(this_info.iter().zip(&other_info).map(|(this, other)| { - let field_attrs = cg::parse_field_attrs::(&this.ast()); - if field_attrs.field_bound { - let ty = &this.ast().ty; - cg::add_predicate( - &mut where_clause, - parse_quote!(#ty: crate::values::distance::ComputeSquaredDistance), - ); - } - - let animation_field_attrs = - cg::parse_field_attrs::(&this.ast()); - - if animation_field_attrs.constant { - quote! { - { - if #this != #other { - return Err(()); - } - crate::values::distance::SquaredDistance::from_sqrt(0.) - } - } - } else { - quote! { - crate::values::distance::ComputeSquaredDistance::compute_squared_distance(#this, #other)? - } - } - }), quote!(+)); - sum - }; - quote! { - #body - (&#this_pattern, &#other_pattern) => { - Ok(#sum) - } - } + let arm = derive_variant_arm(variant, &mut where_clause); + quote! { #body #arm } }); - (match_body, append_error_clause) + (match_body, needs_catchall_branch) }; + input.generics.where_clause = where_clause; - if append_error_clause { - let input_attrs = cg::parse_input_attrs::(&input); - if let Some(fallback) = input_attrs.fallback { - match_body.append_all(quote! { - (this, other) => #fallback(this, other) - }); - } else { - match_body.append_all(quote! { _ => Err(()) }); - } + if needs_catchall_branch { + // This ideally shouldn't be needed, but see: + // https://github.com/rust-lang/rust/issues/68867 + match_body.append_all(quote! { _ => unsafe { debug_unreachable!() } }); } + let fallback = match input_attrs.fallback { + Some(fallback) => quote! { #fallback(self, other) }, + None => quote! { Err(()) }, + }; + let name = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); @@ -103,6 +59,9 @@ pub fn derive(mut input: DeriveInput) -> TokenStream { &self, other: &Self, ) -> Result { + if std::mem::discriminant(self) != std::mem::discriminant(other) { + return #fallback; + } match (self, other) { #match_body } @@ -111,6 +70,60 @@ pub fn derive(mut input: DeriveInput) -> TokenStream { } } +fn derive_variant_arm( + variant: &synstructure::VariantInfo, + mut where_clause: &mut Option, +) -> TokenStream { + let variant_attrs = cg::parse_variant_attrs_from_ast::(&variant.ast()); + let (this_pattern, this_info) = cg::ref_pattern(&variant, "this"); + let (other_pattern, other_info) = cg::ref_pattern(&variant, "other"); + + if variant_attrs.error { + return quote! { + (&#this_pattern, &#other_pattern) => Err(()), + }; + } + + let sum = if this_info.is_empty() { + quote! { crate::values::distance::SquaredDistance::from_sqrt(0.) } + } else { + let mut sum = quote!(); + sum.append_separated(this_info.iter().zip(&other_info).map(|(this, other)| { + let field_attrs = cg::parse_field_attrs::(&this.ast()); + if field_attrs.field_bound { + let ty = &this.ast().ty; + cg::add_predicate( + &mut where_clause, + parse_quote!(#ty: crate::values::distance::ComputeSquaredDistance), + ); + } + + let animation_field_attrs = + cg::parse_field_attrs::(&this.ast()); + + if animation_field_attrs.constant { + quote! { + { + if #this != #other { + return Err(()); + } + crate::values::distance::SquaredDistance::from_sqrt(0.) + } + } + } else { + quote! { + crate::values::distance::ComputeSquaredDistance::compute_squared_distance(#this, #other)? + } + } + }), quote!(+)); + sum + }; + + return quote! { + (&#this_pattern, &#other_pattern) => Ok(#sum), + }; +} + #[darling(attributes(distance), default)] #[derive(Default, FromDeriveInput)] struct DistanceInputAttrs {