From 7400f4070e9d8a40337e16e1e711c29786d732f2 Mon Sep 17 00:00:00 2001 From: JoJoJet <21144246+JoJoJet@users.noreply.github.com> Date: Thu, 30 Mar 2023 11:12:26 -0400 Subject: [PATCH] Make `#[system_param(ignore)]` and `#[world_query(ignore)]` unnecessary (#8030) When using `PhantomData` fields with the `#[derive(SystemParam)]` or `#[derive(WorldQuery)]` macros, the user is required to add the `#[system_param(ignore)]` attribute so that the macro knows to treat that field specially. This is undesirable, since it makes the macro more fragile and less consistent. Implement `SystemParam` and `WorldQuery` for `PhantomData`. This makes the `ignore` attributes unnecessary. Some internal changes make the derive macro compatible with types that have invariant lifetimes, which fixes #8192. From what I can tell, this fix requires `PhantomData` to implement `SystemParam` in order to ensure that all of a type's generic parameters are always constrained. --- + Implemented `SystemParam` and `WorldQuery` for `PhantomData`. + Fixed a miscompilation caused when invariant lifetimes were used with the `SystemParam` macro. --- crates/bevy_ecs/macros/src/fetch.rs | 4 +- crates/bevy_ecs/macros/src/lib.rs | 153 +++++++++++++-------- crates/bevy_ecs/src/query/fetch.rs | 100 +++++++++++++- crates/bevy_ecs/src/system/system_param.rs | 34 ++++- 4 files changed, 232 insertions(+), 59 deletions(-) diff --git a/crates/bevy_ecs/macros/src/fetch.rs b/crates/bevy_ecs/macros/src/fetch.rs index 03cfba5179c831..e1d276efec33c7 100644 --- a/crates/bevy_ecs/macros/src/fetch.rs +++ b/crates/bevy_ecs/macros/src/fetch.rs @@ -378,7 +378,7 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream { #[automatically_derived] #visibility struct #state_struct_name #user_impl_generics #user_where_clauses { #(#field_idents: <#field_types as #path::query::WorldQuery>::State,)* - #(#ignored_field_idents: #ignored_field_types,)* + #(#ignored_field_idents: ::std::marker::PhantomData #ignored_field_types>,)* } /// SAFETY: we assert fields are readonly below @@ -419,7 +419,7 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream { } struct WorldQueryFieldInfo { - /// Has `#[fetch(ignore)]` or `#[filter_fetch(ignore)]` attribute. + /// Has the `#[world_query(ignore)]` attribute. is_ignored: bool, /// All field attributes except for `world_query` ones. attrs: Vec, diff --git a/crates/bevy_ecs/macros/src/lib.rs b/crates/bevy_ecs/macros/src/lib.rs index 110df682609812..f1bb1106809d7a 100644 --- a/crates/bevy_ecs/macros/src/lib.rs +++ b/crates/bevy_ecs/macros/src/lib.rs @@ -6,15 +6,14 @@ mod set; mod states; use crate::{fetch::derive_world_query_impl, set::derive_set}; -use bevy_macro_utils::{ - derive_boxed_label, ensure_no_collision, get_named_struct_fields, BevyManifest, -}; +use bevy_macro_utils::{derive_boxed_label, get_named_struct_fields, BevyManifest}; use proc_macro::TokenStream; use proc_macro2::Span; use quote::{format_ident, quote}; use syn::{ - parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, - ConstParam, DeriveInput, GenericParam, Ident, Index, TypeParam, + parse::ParseStream, parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, + ConstParam, DeriveInput, GenericParam, Ident, Index, Meta, MetaList, NestedMeta, Token, + TypeParam, }; enum BundleFieldKind { @@ -37,23 +36,28 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream { let mut field_kind = Vec::with_capacity(named_fields.len()); - for field in named_fields.iter() { - for attr in field - .attrs - .iter() - .filter(|a| a.path().is_ident(BUNDLE_ATTRIBUTE_NAME)) - { - if let Err(error) = attr.parse_nested_meta(|meta| { - if meta.path.is_ident(BUNDLE_ATTRIBUTE_IGNORE_NAME) { - field_kind.push(BundleFieldKind::Ignore); - Ok(()) - } else { - Err(meta.error(format!( - "Invalid bundle attribute. Use `{BUNDLE_ATTRIBUTE_IGNORE_NAME}`" - ))) + 'field_loop: for field in named_fields.iter() { + for attr in &field.attrs { + if attr.path.is_ident(BUNDLE_ATTRIBUTE_NAME) { + if let Ok(Meta::List(MetaList { nested, .. })) = attr.parse_meta() { + if let Some(&NestedMeta::Meta(Meta::Path(ref path))) = nested.first() { + if path.is_ident(BUNDLE_ATTRIBUTE_IGNORE_NAME) { + field_kind.push(BundleFieldKind::Ignore); + continue 'field_loop; + } + + return syn::Error::new( + path.span(), + format!( + "Invalid bundle attribute. Use `{BUNDLE_ATTRIBUTE_IGNORE_NAME}`" + ), + ) + .into_compile_error() + .into(); + } + + return syn::Error::new(attr.span(), format!("Invalid bundle attribute. Use `#[{BUNDLE_ATTRIBUTE_NAME}({BUNDLE_ATTRIBUTE_IGNORE_NAME})]`")).into_compile_error().into(); } - }) { - return error.into_compile_error().into(); } } @@ -122,9 +126,7 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream { #(#field_from_components)* } } - } - impl #impl_generics #ecs_path::bundle::DynamicBundle for #struct_name #ty_generics #where_clause { #[allow(unused_variables)] #[inline] fn get_components( @@ -224,8 +226,8 @@ pub fn impl_param_set(_input: TokenStream) -> TokenStream { unsafe fn get_param<'w, 's>( state: &'s mut Self::State, system_meta: &SystemMeta, - world: UnsafeWorldCell<'w>, - change_tick: Tick, + world: &'w World, + change_tick: u32, ) -> Self::Item<'w, 's> { ParamSet { param_states: state, @@ -246,10 +248,16 @@ pub fn impl_param_set(_input: TokenStream) -> TokenStream { tokens } +#[derive(Default)] +struct SystemParamFieldAttributes { + pub ignore: bool, +} + +static SYSTEM_PARAM_ATTRIBUTE_NAME: &str = "system_param"; + /// Implement `SystemParam` to use a struct as a parameter in a system #[proc_macro_derive(SystemParam, attributes(system_param))] pub fn derive_system_param(input: TokenStream) -> TokenStream { - let token_stream = input.clone(); let ast = parse_macro_input!(input as DeriveInput); let syn::Data::Struct(syn::DataStruct { fields: field_definitions, .. }) = ast.data else { return syn::Error::new(ast.span(), "Invalid `SystemParam` type: expected a `struct`") @@ -258,20 +266,53 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream { }; let path = bevy_ecs_path(); + let field_attributes = field_definitions + .iter() + .map(|field| { + ( + field, + field + .attrs + .iter() + .find(|a| *a.path.get_ident().as_ref().unwrap() == SYSTEM_PARAM_ATTRIBUTE_NAME) + .map_or_else(SystemParamFieldAttributes::default, |a| { + syn::custom_keyword!(ignore); + let mut attributes = SystemParamFieldAttributes::default(); + a.parse_args_with(|input: ParseStream| { + if input.parse::>()?.is_some() { + attributes.ignore = true; + } + Ok(()) + }) + .expect("Invalid 'system_param' attribute format."); + + attributes + }), + ) + }) + .collect::>(); + let mut field_locals = Vec::new(); let mut fields = Vec::new(); let mut field_types = Vec::new(); - for (i, field) in field_definitions.iter().enumerate() { - field_locals.push(format_ident!("f{i}")); - let i = Index::from(i); - fields.push( - field - .ident - .as_ref() - .map(|f| quote! { #f }) - .unwrap_or_else(|| quote! { #i }), - ); - field_types.push(&field.ty); + let mut ignored_fields = Vec::new(); + let mut ignored_field_types = Vec::new(); + for (i, (field, attrs)) in field_attributes.iter().enumerate() { + if attrs.ignore { + ignored_fields.push(field.ident.as_ref().unwrap()); + ignored_field_types.push(&field.ty); + } else { + field_locals.push(format_ident!("f{i}")); + let i = Index::from(i); + fields.push( + field + .ident + .as_ref() + .map(|f| quote! { #f }) + .unwrap_or_else(|| quote! { #i }), + ); + field_types.push(&field.ty); + } } let generics = ast.generics; @@ -303,7 +344,7 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream { let shadowed_lifetimes: Vec<_> = generics.lifetimes().map(|_| quote!('_)).collect(); - let mut punctuated_generics = Punctuated::<_, Comma>::new(); + let mut punctuated_generics = Punctuated::<_, Token![,]>::new(); punctuated_generics.extend(lifetimeless_generics.iter().map(|g| match g { GenericParam::Type(g) => GenericParam::Type(TypeParam { default: None, @@ -316,14 +357,14 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream { _ => unreachable!(), })); - let mut punctuated_generic_idents = Punctuated::<_, Comma>::new(); + let mut punctuated_generic_idents = Punctuated::<_, Token![,]>::new(); punctuated_generic_idents.extend(lifetimeless_generics.iter().map(|g| match g { GenericParam::Type(g) => &g.ident, GenericParam::Const(g) => &g.ident, _ => unreachable!(), })); - let punctuated_generics_no_bounds: Punctuated<_, Comma> = lifetimeless_generics + let punctuated_generics_no_bounds: Punctuated<_, Token![,]> = lifetimeless_generics .iter() .map(|&g| match g.clone() { GenericParam::Type(mut g) => { @@ -337,6 +378,13 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream { let mut tuple_types: Vec<_> = field_types.iter().map(|x| quote! { #x }).collect(); let mut tuple_patterns: Vec<_> = field_locals.iter().map(|x| quote! { #x }).collect(); + tuple_types.extend( + ignored_field_types + .iter() + .map(|ty| parse_quote!(::std::marker::PhantomData::<#ty>)), + ); + tuple_patterns.extend(ignored_field_types.iter().map(|_| parse_quote!(_))); + // If the number of fields exceeds the 16-parameter limit, // fold the fields into tuples of tuples until we are below the limit. const LIMIT: usize = 16; @@ -358,12 +406,10 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream { .push(syn::parse_quote!(#field_type: #path::system::ReadOnlySystemParam)); } - let fields_alias = - ensure_no_collision(format_ident!("__StructFieldsAlias"), token_stream.clone()); + let fields_alias = format_ident!("__StructFieldsAlias"); let struct_name = &ast.ident; let state_struct_visibility = &ast.vis; - let state_struct_name = ensure_no_collision(format_ident!("FetchState"), token_stream); TokenStream::from(quote! { // We define the FetchState struct in an anonymous scope to avoid polluting the user namespace. @@ -374,7 +420,7 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream { type #fields_alias <'w, 's, #punctuated_generics_no_bounds> = (#(#tuple_types,)*); #[doc(hidden)] - #state_struct_visibility struct #state_struct_name <#(#lifetimeless_generics,)*> + #state_struct_visibility struct FetchState <#(#lifetimeless_generics,)*> #where_clause { state: <#fields_alias::<'static, 'static, #punctuated_generic_idents> as #path::system::SystemParam>::State, } @@ -382,11 +428,11 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream { unsafe impl<#punctuated_generics> #path::system::SystemParam for #struct_name <#(#shadowed_lifetimes,)* #punctuated_generic_idents> #where_clause { - type State = #state_struct_name<#punctuated_generic_idents>; + type State = FetchState<#punctuated_generic_idents>; type Item<'w, 's> = #struct_name #ty_generics; fn init_state(world: &mut #path::world::World, system_meta: &mut #path::system::SystemMeta) -> Self::State { - #state_struct_name { + FetchState { state: <#fields_alias::<'_, '_, #punctuated_generic_idents> as #path::system::SystemParam>::init_state(world, system_meta), } } @@ -402,14 +448,15 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream { unsafe fn get_param<'w, 's>( state: &'s mut Self::State, system_meta: &#path::system::SystemMeta, - world: #path::world::unsafe_world_cell::UnsafeWorldCell<'w>, - change_tick: #path::component::Tick, + world: &'w #path::world::World, + change_tick: u32, ) -> Self::Item<'w, 's> { let (#(#tuple_patterns,)*) = < (#(#tuple_types,)*) as #path::system::SystemParam >::get_param(&mut state.state, system_meta, world, change_tick); #struct_name { #(#fields: #field_locals,)* + #(#ignored_fields: std::default::Default::default(),)* } } } @@ -423,7 +470,8 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream { /// Implement `WorldQuery` to use a struct as a parameter in a query #[proc_macro_derive(WorldQuery, attributes(world_query))] pub fn derive_world_query(input: TokenStream) -> TokenStream { - derive_world_query_impl(input) + let ast = parse_macro_input!(input as DeriveInput); + derive_world_query_impl(ast) } /// Derive macro generating an impl of the trait `ScheduleLabel`. @@ -439,7 +487,7 @@ pub fn derive_schedule_label(input: TokenStream) -> TokenStream { } /// Derive macro generating an impl of the trait `SystemSet`. -#[proc_macro_derive(SystemSet)] +#[proc_macro_derive(SystemSet, attributes(system_set))] pub fn derive_system_set(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); let mut trait_path = bevy_ecs_path(); @@ -452,11 +500,6 @@ pub(crate) fn bevy_ecs_path() -> syn::Path { BevyManifest::default().get_path("bevy_ecs") } -#[proc_macro_derive(Event)] -pub fn derive_event(input: TokenStream) -> TokenStream { - component::derive_event(input) -} - #[proc_macro_derive(Resource)] pub fn derive_resource(input: TokenStream) -> TokenStream { component::derive_resource(input) diff --git a/crates/bevy_ecs/src/query/fetch.rs b/crates/bevy_ecs/src/query/fetch.rs index 02229420fa9944..9e6a9291b1e3f7 100644 --- a/crates/bevy_ecs/src/query/fetch.rs +++ b/crates/bevy_ecs/src/query/fetch.rs @@ -281,6 +281,24 @@ use std::{cell::UnsafeCell, marker::PhantomData}; /// # bevy_ecs::system::assert_is_system(my_system); /// ``` /// +/// # Generic Queries +/// +/// When writing generic code, it is often necessary to use [`PhantomData`] +/// to constrain type parameters. Since `WorldQuery` is implemented for all +/// `PhantomData` types, this pattern can be used with this macro. +/// +/// ``` +/// # use bevy_ecs::{prelude::*, query::WorldQuery}; +/// # use std::marker::PhantomData; +/// #[derive(WorldQuery)] +/// pub struct GenericQuery { +/// id: Entity, +/// marker: PhantomData, +/// } +/// # fn my_system(q: Query>) {} +/// # bevy_ecs::system::assert_is_system(my_system); +/// ``` +/// /// # Safety /// /// Component access of `Self::ReadOnly` must be a subset of `Self` @@ -1563,7 +1581,6 @@ macro_rules! impl_anytuple_fetch { /// SAFETY: each item in the tuple is read only unsafe impl<$($name: ReadOnlyWorldQuery),*> ReadOnlyWorldQuery for AnyOf<($($name,)*)> {} - }; } @@ -1643,6 +1660,71 @@ unsafe impl WorldQuery for NopWorldQuery { /// SAFETY: `NopFetch` never accesses any data unsafe impl ReadOnlyWorldQuery for NopWorldQuery {} +/// SAFETY: `PhantomData` never accesses any world data. +unsafe impl WorldQuery for PhantomData { + type Item<'a> = (); + type Fetch<'a> = (); + type ReadOnly = Self; + type State = (); + + fn shrink<'wlong: 'wshort, 'wshort>(_item: Self::Item<'wlong>) -> Self::Item<'wshort> {} + + unsafe fn init_fetch<'w>( + _world: &'w World, + _state: &Self::State, + _last_change_tick: u32, + _change_tick: u32, + ) -> Self::Fetch<'w> { + } + + unsafe fn clone_fetch<'w>(_fetch: &Self::Fetch<'w>) -> Self::Fetch<'w> {} + + // `PhantomData` does not match any components, so all components it matches + // are stored in a Table (vacuous truth). + const IS_DENSE: bool = true; + // `PhantomData` matches every entity in each archetype. + const IS_ARCHETYPAL: bool = true; + + unsafe fn set_archetype<'w>( + _fetch: &mut Self::Fetch<'w>, + _state: &Self::State, + _archetype: &'w Archetype, + _table: &'w Table, + ) { + } + + unsafe fn set_table<'w>(_fetch: &mut Self::Fetch<'w>, _state: &Self::State, _table: &'w Table) { + } + + unsafe fn fetch<'w>( + _fetch: &mut Self::Fetch<'w>, + _entity: Entity, + _table_row: TableRow, + ) -> Self::Item<'w> { + } + + fn update_component_access(_state: &Self::State, _access: &mut FilteredAccess) {} + + fn update_archetype_component_access( + _state: &Self::State, + _archetype: &Archetype, + _access: &mut Access, + ) { + } + + fn init_state(_world: &mut World) -> Self::State {} + + fn matches_component_set( + _state: &Self::State, + _set_contains_id: &impl Fn(ComponentId) -> bool, + ) -> bool { + true + } +} + +/// SAFETY: `PhantomData` never accesses any world data. +unsafe impl ReadOnlyWorldQuery for PhantomData {} + #[cfg(test)] mod tests { use super::*; @@ -1651,6 +1733,22 @@ mod tests { #[derive(Component)] pub struct A; + // Compile test for https://github.com/bevyengine/bevy/pull/8030. + #[test] + fn world_query_phantom_data() { + #[derive(WorldQuery)] + pub struct IgnoredQuery { + id: Entity, + #[world_query(ignore)] + _marker: PhantomData, + _marker2: PhantomData, + } + + fn ignored_system(_: Query>) {} + + crate::system::assert_is_system(ignored_system); + } + // Ensures that each field of a `WorldQuery` struct's read-only variant // has the same visibility as its corresponding mutable field. #[test] diff --git a/crates/bevy_ecs/src/system/system_param.rs b/crates/bevy_ecs/src/system/system_param.rs index 9e14e6a3a2d049..258687ad9f0329 100644 --- a/crates/bevy_ecs/src/system/system_param.rs +++ b/crates/bevy_ecs/src/system/system_param.rs @@ -19,6 +19,7 @@ use bevy_utils::{all_tuples, synccell::SyncCell}; use std::{ borrow::Cow, fmt::Debug, + marker::PhantomData, ops::{Deref, DerefMut}, }; @@ -65,6 +66,11 @@ use std::{ /// # bevy_ecs::system::assert_is_system(my_system::<()>); /// ``` /// +/// ## `PhantomData` +/// +/// [`PhantomData`] is a special type of `SystemParam` that does nothing. +/// This is useful for constraining generic types or lifetimes. +/// /// # Generic `SystemParam`s /// /// When using the derive macro, you may see an error in the form of: @@ -1476,7 +1482,6 @@ pub mod lifetimeless { /// #[derive(SystemParam)] /// struct GenericParam<'w, 's, T: SystemParam> { /// field: T, -/// #[system_param(ignore)] /// // Use the lifetimes in this type, or they will be unbound. /// phantom: core::marker::PhantomData<&'w &'s ()> /// } @@ -1542,6 +1547,26 @@ unsafe impl SystemParam for StaticSystemParam<'_, '_, } } +// SAFETY: No world access. +unsafe impl SystemParam for PhantomData { + type State = (); + type Item<'world, 'state> = Self; + + fn init_state(_world: &mut World, _system_meta: &mut SystemMeta) -> Self::State {} + + unsafe fn get_param<'world, 'state>( + _state: &'state mut Self::State, + _system_meta: &SystemMeta, + _world: &'world World, + _change_tick: u32, + ) -> Self::Item<'world, 'state> { + PhantomData + } +} + +// SAFETY: No world access. +unsafe impl ReadOnlySystemParam for PhantomData {} + #[cfg(test)] mod tests { use super::*; @@ -1616,6 +1641,7 @@ mod tests { _foo: Res<'w, T>, #[system_param(ignore)] marker: PhantomData<&'w Marker>, + marker2: PhantomData<&'w Marker>, } // Compile tests for https://github.com/bevyengine/bevy/pull/6957. @@ -1644,4 +1670,10 @@ mod tests { { _q: Query<'w, 's, Q, ()>, } + + // Regression test for https://github.com/bevyengine/bevy/issues/8192. + #[derive(SystemParam)] + pub struct InvariantParam<'w, 's> { + _set: ParamSet<'w, 's, (Query<'w, 's, ()>,)>, + } }