Skip to content

Commit

Permalink
Make #[system_param(ignore)] and #[world_query(ignore)] unnecessa…
Browse files Browse the repository at this point in the history
…ry (bevyengine#8030)

When using `PhantomData` fields with the `#[derive(SystemParam)]` or
`#[derive(WorldQuery)]` macros, the user is required to add the
`#[system_param(ignore)]` attribute so that the macro knows to treat
that field specially. This is undesirable, since it makes the macro more
fragile and less consistent.

Implement `SystemParam` and `WorldQuery` for `PhantomData`. This makes
the `ignore` attributes unnecessary.

Some internal changes make the derive macro compatible with types that
have invariant lifetimes, which fixes bevyengine#8192. From what I can tell, this
fix requires `PhantomData` to implement `SystemParam` in order to ensure
that all of a type's generic parameters are always constrained.

---

+ Implemented `SystemParam` and `WorldQuery` for `PhantomData<T>`.
+ Fixed a miscompilation caused when invariant lifetimes were used with
the `SystemParam` macro.
  • Loading branch information
JoJoJet authored and UkoeHB committed Sep 10, 2023
1 parent 3bec4a7 commit 7400f40
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 59 deletions.
4 changes: 2 additions & 2 deletions crates/bevy_ecs/macros/src/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
#[automatically_derived]
#visibility struct #state_struct_name #user_impl_generics #user_where_clauses {
#(#field_idents: <#field_types as #path::query::WorldQuery>::State,)*
#(#ignored_field_idents: #ignored_field_types,)*
#(#ignored_field_idents: ::std::marker::PhantomData<fn() -> #ignored_field_types>,)*
}

/// SAFETY: we assert fields are readonly below
Expand Down Expand Up @@ -419,7 +419,7 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
}

struct WorldQueryFieldInfo {
/// Has `#[fetch(ignore)]` or `#[filter_fetch(ignore)]` attribute.
/// Has the `#[world_query(ignore)]` attribute.
is_ignored: bool,
/// All field attributes except for `world_query` ones.
attrs: Vec<Attribute>,
Expand Down
153 changes: 98 additions & 55 deletions crates/bevy_ecs/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@ mod set;
mod states;

use crate::{fetch::derive_world_query_impl, set::derive_set};
use bevy_macro_utils::{
derive_boxed_label, ensure_no_collision, get_named_struct_fields, BevyManifest,
};
use bevy_macro_utils::{derive_boxed_label, get_named_struct_fields, BevyManifest};
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{
parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma,
ConstParam, DeriveInput, GenericParam, Ident, Index, TypeParam,
parse::ParseStream, parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned,
ConstParam, DeriveInput, GenericParam, Ident, Index, Meta, MetaList, NestedMeta, Token,
TypeParam,
};

enum BundleFieldKind {
Expand All @@ -37,23 +36,28 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream {

let mut field_kind = Vec::with_capacity(named_fields.len());

for field in named_fields.iter() {
for attr in field
.attrs
.iter()
.filter(|a| a.path().is_ident(BUNDLE_ATTRIBUTE_NAME))
{
if let Err(error) = attr.parse_nested_meta(|meta| {
if meta.path.is_ident(BUNDLE_ATTRIBUTE_IGNORE_NAME) {
field_kind.push(BundleFieldKind::Ignore);
Ok(())
} else {
Err(meta.error(format!(
"Invalid bundle attribute. Use `{BUNDLE_ATTRIBUTE_IGNORE_NAME}`"
)))
'field_loop: for field in named_fields.iter() {
for attr in &field.attrs {
if attr.path.is_ident(BUNDLE_ATTRIBUTE_NAME) {
if let Ok(Meta::List(MetaList { nested, .. })) = attr.parse_meta() {
if let Some(&NestedMeta::Meta(Meta::Path(ref path))) = nested.first() {
if path.is_ident(BUNDLE_ATTRIBUTE_IGNORE_NAME) {
field_kind.push(BundleFieldKind::Ignore);
continue 'field_loop;
}

return syn::Error::new(
path.span(),
format!(
"Invalid bundle attribute. Use `{BUNDLE_ATTRIBUTE_IGNORE_NAME}`"
),
)
.into_compile_error()
.into();
}

return syn::Error::new(attr.span(), format!("Invalid bundle attribute. Use `#[{BUNDLE_ATTRIBUTE_NAME}({BUNDLE_ATTRIBUTE_IGNORE_NAME})]`")).into_compile_error().into();
}
}) {
return error.into_compile_error().into();
}
}

Expand Down Expand Up @@ -122,9 +126,7 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream {
#(#field_from_components)*
}
}
}

impl #impl_generics #ecs_path::bundle::DynamicBundle for #struct_name #ty_generics #where_clause {
#[allow(unused_variables)]
#[inline]
fn get_components(
Expand Down Expand Up @@ -224,8 +226,8 @@ pub fn impl_param_set(_input: TokenStream) -> TokenStream {
unsafe fn get_param<'w, 's>(
state: &'s mut Self::State,
system_meta: &SystemMeta,
world: UnsafeWorldCell<'w>,
change_tick: Tick,
world: &'w World,
change_tick: u32,
) -> Self::Item<'w, 's> {
ParamSet {
param_states: state,
Expand All @@ -246,10 +248,16 @@ pub fn impl_param_set(_input: TokenStream) -> TokenStream {
tokens
}

#[derive(Default)]
struct SystemParamFieldAttributes {
pub ignore: bool,
}

static SYSTEM_PARAM_ATTRIBUTE_NAME: &str = "system_param";

/// Implement `SystemParam` to use a struct as a parameter in a system
#[proc_macro_derive(SystemParam, attributes(system_param))]
pub fn derive_system_param(input: TokenStream) -> TokenStream {
let token_stream = input.clone();
let ast = parse_macro_input!(input as DeriveInput);
let syn::Data::Struct(syn::DataStruct { fields: field_definitions, .. }) = ast.data else {
return syn::Error::new(ast.span(), "Invalid `SystemParam` type: expected a `struct`")
Expand All @@ -258,20 +266,53 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream {
};
let path = bevy_ecs_path();

let field_attributes = field_definitions
.iter()
.map(|field| {
(
field,
field
.attrs
.iter()
.find(|a| *a.path.get_ident().as_ref().unwrap() == SYSTEM_PARAM_ATTRIBUTE_NAME)
.map_or_else(SystemParamFieldAttributes::default, |a| {
syn::custom_keyword!(ignore);
let mut attributes = SystemParamFieldAttributes::default();
a.parse_args_with(|input: ParseStream| {
if input.parse::<Option<ignore>>()?.is_some() {
attributes.ignore = true;
}
Ok(())
})
.expect("Invalid 'system_param' attribute format.");

attributes
}),
)
})
.collect::<Vec<_>>();

let mut field_locals = Vec::new();
let mut fields = Vec::new();
let mut field_types = Vec::new();
for (i, field) in field_definitions.iter().enumerate() {
field_locals.push(format_ident!("f{i}"));
let i = Index::from(i);
fields.push(
field
.ident
.as_ref()
.map(|f| quote! { #f })
.unwrap_or_else(|| quote! { #i }),
);
field_types.push(&field.ty);
let mut ignored_fields = Vec::new();
let mut ignored_field_types = Vec::new();
for (i, (field, attrs)) in field_attributes.iter().enumerate() {
if attrs.ignore {
ignored_fields.push(field.ident.as_ref().unwrap());
ignored_field_types.push(&field.ty);
} else {
field_locals.push(format_ident!("f{i}"));
let i = Index::from(i);
fields.push(
field
.ident
.as_ref()
.map(|f| quote! { #f })
.unwrap_or_else(|| quote! { #i }),
);
field_types.push(&field.ty);
}
}

let generics = ast.generics;
Expand Down Expand Up @@ -303,7 +344,7 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream {

let shadowed_lifetimes: Vec<_> = generics.lifetimes().map(|_| quote!('_)).collect();

let mut punctuated_generics = Punctuated::<_, Comma>::new();
let mut punctuated_generics = Punctuated::<_, Token![,]>::new();
punctuated_generics.extend(lifetimeless_generics.iter().map(|g| match g {
GenericParam::Type(g) => GenericParam::Type(TypeParam {
default: None,
Expand All @@ -316,14 +357,14 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream {
_ => unreachable!(),
}));

let mut punctuated_generic_idents = Punctuated::<_, Comma>::new();
let mut punctuated_generic_idents = Punctuated::<_, Token![,]>::new();
punctuated_generic_idents.extend(lifetimeless_generics.iter().map(|g| match g {
GenericParam::Type(g) => &g.ident,
GenericParam::Const(g) => &g.ident,
_ => unreachable!(),
}));

let punctuated_generics_no_bounds: Punctuated<_, Comma> = lifetimeless_generics
let punctuated_generics_no_bounds: Punctuated<_, Token![,]> = lifetimeless_generics
.iter()
.map(|&g| match g.clone() {
GenericParam::Type(mut g) => {
Expand All @@ -337,6 +378,13 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream {
let mut tuple_types: Vec<_> = field_types.iter().map(|x| quote! { #x }).collect();
let mut tuple_patterns: Vec<_> = field_locals.iter().map(|x| quote! { #x }).collect();

tuple_types.extend(
ignored_field_types
.iter()
.map(|ty| parse_quote!(::std::marker::PhantomData::<#ty>)),
);
tuple_patterns.extend(ignored_field_types.iter().map(|_| parse_quote!(_)));

// If the number of fields exceeds the 16-parameter limit,
// fold the fields into tuples of tuples until we are below the limit.
const LIMIT: usize = 16;
Expand All @@ -358,12 +406,10 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream {
.push(syn::parse_quote!(#field_type: #path::system::ReadOnlySystemParam));
}

let fields_alias =
ensure_no_collision(format_ident!("__StructFieldsAlias"), token_stream.clone());
let fields_alias = format_ident!("__StructFieldsAlias");

let struct_name = &ast.ident;
let state_struct_visibility = &ast.vis;
let state_struct_name = ensure_no_collision(format_ident!("FetchState"), token_stream);

TokenStream::from(quote! {
// We define the FetchState struct in an anonymous scope to avoid polluting the user namespace.
Expand All @@ -374,19 +420,19 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream {
type #fields_alias <'w, 's, #punctuated_generics_no_bounds> = (#(#tuple_types,)*);

#[doc(hidden)]
#state_struct_visibility struct #state_struct_name <#(#lifetimeless_generics,)*>
#state_struct_visibility struct FetchState <#(#lifetimeless_generics,)*>
#where_clause {
state: <#fields_alias::<'static, 'static, #punctuated_generic_idents> as #path::system::SystemParam>::State,
}

unsafe impl<#punctuated_generics> #path::system::SystemParam for
#struct_name <#(#shadowed_lifetimes,)* #punctuated_generic_idents> #where_clause
{
type State = #state_struct_name<#punctuated_generic_idents>;
type State = FetchState<#punctuated_generic_idents>;
type Item<'w, 's> = #struct_name #ty_generics;

fn init_state(world: &mut #path::world::World, system_meta: &mut #path::system::SystemMeta) -> Self::State {
#state_struct_name {
FetchState {
state: <#fields_alias::<'_, '_, #punctuated_generic_idents> as #path::system::SystemParam>::init_state(world, system_meta),
}
}
Expand All @@ -402,14 +448,15 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream {
unsafe fn get_param<'w, 's>(
state: &'s mut Self::State,
system_meta: &#path::system::SystemMeta,
world: #path::world::unsafe_world_cell::UnsafeWorldCell<'w>,
change_tick: #path::component::Tick,
world: &'w #path::world::World,
change_tick: u32,
) -> Self::Item<'w, 's> {
let (#(#tuple_patterns,)*) = <
(#(#tuple_types,)*) as #path::system::SystemParam
>::get_param(&mut state.state, system_meta, world, change_tick);
#struct_name {
#(#fields: #field_locals,)*
#(#ignored_fields: std::default::Default::default(),)*
}
}
}
Expand All @@ -423,7 +470,8 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream {
/// Implement `WorldQuery` to use a struct as a parameter in a query
#[proc_macro_derive(WorldQuery, attributes(world_query))]
pub fn derive_world_query(input: TokenStream) -> TokenStream {
derive_world_query_impl(input)
let ast = parse_macro_input!(input as DeriveInput);
derive_world_query_impl(ast)
}

/// Derive macro generating an impl of the trait `ScheduleLabel`.
Expand All @@ -439,7 +487,7 @@ pub fn derive_schedule_label(input: TokenStream) -> TokenStream {
}

/// Derive macro generating an impl of the trait `SystemSet`.
#[proc_macro_derive(SystemSet)]
#[proc_macro_derive(SystemSet, attributes(system_set))]
pub fn derive_system_set(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let mut trait_path = bevy_ecs_path();
Expand All @@ -452,11 +500,6 @@ pub(crate) fn bevy_ecs_path() -> syn::Path {
BevyManifest::default().get_path("bevy_ecs")
}

#[proc_macro_derive(Event)]
pub fn derive_event(input: TokenStream) -> TokenStream {
component::derive_event(input)
}

#[proc_macro_derive(Resource)]
pub fn derive_resource(input: TokenStream) -> TokenStream {
component::derive_resource(input)
Expand Down
Loading

0 comments on commit 7400f40

Please sign in to comment.