Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend the WorldQuery macro to tuple structs #8119

Merged
merged 12 commits into from
Apr 4, 2023
122 changes: 82 additions & 40 deletions crates/bevy_ecs/macros/src/fetch.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::{quote, ToTokens};
use quote::{format_ident, quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse_quote,
punctuated::Punctuated,
Attribute, Data, DataStruct, DeriveInput, Field, Fields,
Attribute, Data, DataStruct, DeriveInput, Field, Index,
};

use crate::bevy_ecs_path;
Expand Down Expand Up @@ -112,37 +112,59 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {

let state_struct_name = Ident::new(&format!("{struct_name}State"), Span::call_site());

let fields = match &ast.data {
Data::Struct(DataStruct {
fields: Fields::Named(fields),
..
}) => &fields.named,
_ => panic!("Expected a struct with named fields"),
let Data::Struct(DataStruct { fields, .. }) = &ast.data else {
return syn::Error::new(
Span::call_site(),
"#[derive(WorldQuery)]` only supports structs",
)
.into_compile_error()
.into()
};
if fields.is_empty() {
return syn::Error::new(
Span::call_site(),
"#[derive(WorldQuery)]` does not support fieldless structs",
JoJoJet marked this conversation as resolved.
Show resolved Hide resolved
)
.into_compile_error()
.into();
}

let mut ignored_field_attrs = Vec::new();
let mut ignored_field_visibilities = Vec::new();
let mut ignored_field_idents = Vec::new();
let mut ignored_named_field_idents = Vec::new();
let mut ignored_field_types = Vec::new();
let mut field_attrs = Vec::new();
let mut field_visibilities = Vec::new();
let mut field_idents = Vec::new();
let mut named_field_idents = Vec::new();
let mut field_types = Vec::new();
let mut read_only_field_types = Vec::new();

for field in fields {
for (i, field) in fields.iter().enumerate() {
let WorldQueryFieldInfo { is_ignored, attrs } = read_world_query_field_info(field);

let field_ident = field.ident.as_ref().unwrap().clone();
let named_field_ident = field
.ident
.as_ref()
.cloned()
.unwrap_or_else(|| format_ident!("f{i}"));
let i = Index::from(i);
let field_ident = field
.ident
.as_ref()
.map_or(quote! { #i }, |i| quote! { #i });
if is_ignored {
ignored_field_attrs.push(attrs);
ignored_field_visibilities.push(field.vis.clone());
ignored_field_idents.push(field_ident.clone());
ignored_field_idents.push(field_ident);
ignored_named_field_idents.push(named_field_ident);
ignored_field_types.push(field.ty.clone());
} else {
field_attrs.push(attrs);
field_visibilities.push(field.vis.clone());
field_idents.push(field_ident.clone());
field_idents.push(field_ident);
named_field_idents.push(named_field_ident);
let field_ty = field.ty.clone();
field_types.push(quote!(#field_ty));
read_only_field_types.push(quote!(<#field_ty as #path::query::WorldQuery>::ReadOnly));
Expand Down Expand Up @@ -176,16 +198,36 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
&field_types
};

let item_struct = quote! {
#derive_macro_call
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility struct #item_struct_name #user_impl_generics_with_world #user_where_clauses_with_world {
#(#(#field_attrs)* #field_visibilities #field_idents: <#field_types as #path::query::WorldQuery>::Item<'__w>,)*
#(#(#ignored_field_attrs)* #ignored_field_visibilities #ignored_field_idents: #ignored_field_types,)*
}
let item_struct = match fields {
syn::Fields::Named(_) => quote! {
#derive_macro_call
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility struct #item_struct_name #user_impl_generics_with_world #user_where_clauses_with_world {
#(#(#field_attrs)* #field_visibilities #field_idents: <#field_types as #path::query::WorldQuery>::Item<'__w>,)*
#(#(#ignored_field_attrs)* #ignored_field_visibilities #ignored_field_idents: #ignored_field_types,)*
}
},
syn::Fields::Unnamed(_) => quote! {
#derive_macro_call
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility struct #item_struct_name #user_impl_generics_with_world #user_where_clauses_with_world(
#( #field_visibilities <#field_types as #path::query::WorldQuery>::Item<'__w>, )*
);
},
syn::Fields::Unit => quote! {
#derive_macro_call
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility struct #item_struct_name #user_impl_generics_with_world #user_where_clauses_with_world;
},
};

let query_impl = quote! {
Expand All @@ -195,8 +237,8 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
#[doc = "`], used to define the world data accessed by this query."]
#[automatically_derived]
#visibility struct #fetch_struct_name #user_impl_generics_with_world #user_where_clauses_with_world {
#(#field_idents: <#field_types as #path::query::WorldQuery>::Fetch<'__w>,)*
#(#ignored_field_idents: #ignored_field_types,)*
#(#named_field_idents: <#field_types as #path::query::WorldQuery>::Fetch<'__w>,)*
#(#ignored_named_field_idents: #ignored_field_types,)*
}

// SAFETY: `update_component_access` and `update_archetype_component_access` are called on every field
Expand Down Expand Up @@ -228,15 +270,15 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
_this_run: #path::component::Tick,
) -> <Self as #path::query::WorldQuery>::Fetch<'__w> {
#fetch_struct_name {
#(#field_idents:
#(#named_field_idents:
<#field_types>::init_fetch(
_world,
&state.#field_idents,
&state.#named_field_idents,
_last_run,
_this_run,
),
)*
#(#ignored_field_idents: Default::default(),)*
#(#ignored_named_field_idents: Default::default(),)*
}
}

Expand All @@ -245,10 +287,10 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
) -> <Self as #path::query::WorldQuery>::Fetch<'__w> {
#fetch_struct_name {
#(
#field_idents: <#field_types>::clone_fetch(& _fetch. #field_idents),
#named_field_idents: <#field_types>::clone_fetch(& _fetch. #named_field_idents),
)*
#(
#ignored_field_idents: Default::default(),
#ignored_named_field_idents: Default::default(),
)*
}
}
Expand All @@ -265,7 +307,7 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
_archetype: &'__w #path::archetype::Archetype,
_table: &'__w #path::storage::Table
) {
#(<#field_types>::set_archetype(&mut _fetch.#field_idents, &_state.#field_idents, _archetype, _table);)*
#(<#field_types>::set_archetype(&mut _fetch.#named_field_idents, &_state.#named_field_idents, _archetype, _table);)*
}

/// SAFETY: we call `set_table` for each member that implements `Fetch`
Expand All @@ -275,7 +317,7 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
_state: &Self::State,
_table: &'__w #path::storage::Table
) {
#(<#field_types>::set_table(&mut _fetch.#field_idents, &_state.#field_idents, _table);)*
#(<#field_types>::set_table(&mut _fetch.#named_field_idents, &_state.#named_field_idents, _table);)*
}

/// SAFETY: we call `fetch` for each member that implements `Fetch`.
Expand All @@ -286,7 +328,7 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
_table_row: #path::storage::TableRow,
) -> <Self as #path::query::WorldQuery>::Item<'__w> {
Self::Item {
#(#field_idents: <#field_types>::fetch(&mut _fetch.#field_idents, _entity, _table_row),)*
#(#field_idents: <#field_types>::fetch(&mut _fetch.#named_field_idents, _entity, _table_row),)*
#(#ignored_field_idents: Default::default(),)*
}
}
Expand All @@ -298,11 +340,11 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
_entity: #path::entity::Entity,
_table_row: #path::storage::TableRow,
) -> bool {
true #(&& <#field_types>::filter_fetch(&mut _fetch.#field_idents, _entity, _table_row))*
true #(&& <#field_types>::filter_fetch(&mut _fetch.#named_field_idents, _entity, _table_row))*
}

fn update_component_access(state: &Self::State, _access: &mut #path::query::FilteredAccess<#path::component::ComponentId>) {
#( <#field_types>::update_component_access(&state.#field_idents, _access); )*
#( <#field_types>::update_component_access(&state.#named_field_idents, _access); )*
}

fn update_archetype_component_access(
Expand All @@ -311,19 +353,19 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
_access: &mut #path::query::Access<#path::archetype::ArchetypeComponentId>
) {
#(
<#field_types>::update_archetype_component_access(&state.#field_idents, _archetype, _access);
<#field_types>::update_archetype_component_access(&state.#named_field_idents, _archetype, _access);
)*
}

fn init_state(world: &mut #path::world::World) -> #state_struct_name #user_ty_generics {
#state_struct_name {
#(#field_idents: <#field_types>::init_state(world),)*
#(#named_field_idents: <#field_types>::init_state(world),)*
#(#ignored_field_idents: Default::default(),)*
}
}

fn matches_component_set(state: &Self::State, _set_contains_id: &impl Fn(#path::component::ComponentId) -> bool) -> bool {
true #(&& <#field_types>::matches_component_set(&state.#field_idents, _set_contains_id))*
true #(&& <#field_types>::matches_component_set(&state.#named_field_idents, _set_contains_id))*
}
}
};
Expand All @@ -339,7 +381,7 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
#[doc = "`]."]
#[automatically_derived]
#visibility struct #read_only_struct_name #user_impl_generics #user_where_clauses {
#( #field_idents: #read_only_field_types, )*
#( #named_field_idents: #read_only_field_types, )*
#(#(#ignored_field_attrs)* #ignored_field_visibilities #ignored_field_idents: #ignored_field_types,)*
}

Expand Down Expand Up @@ -386,8 +428,8 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
#[doc = "`], used for caching."]
#[automatically_derived]
#visibility struct #state_struct_name #user_impl_generics #user_where_clauses {
#(#field_idents: <#field_types as #path::query::WorldQuery>::State,)*
#(#ignored_field_idents: #ignored_field_types,)*
#(#named_field_idents: <#field_types as #path::query::WorldQuery>::State,)*
#(#ignored_named_field_idents: #ignored_field_types,)*
}

#mutable_impl
Expand Down
32 changes: 31 additions & 1 deletion crates/bevy_ecs/src/query/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ use std::{cell::UnsafeCell, marker::PhantomData};
/// must implement [`Default`] and will be initialized to the default value as defined
/// by the trait.
///
/// The derive macro only supports regular structs (structs with named fields).
/// The derive macro only supports structs.
///
/// ```
/// # use bevy_ecs::prelude::*;
Expand Down Expand Up @@ -1388,3 +1388,33 @@ unsafe impl<Q: WorldQuery> WorldQuery for NopWorldQuery<Q> {

/// SAFETY: `NopFetch` never accesses any data
unsafe impl<Q: WorldQuery> ReadOnlyWorldQuery for NopWorldQuery<Q> {}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
self as bevy_ecs,
system::{assert_is_system, Query},
};

#[derive(Component)]
struct A;

#[derive(Component)]
struct B;

#[test]
fn world_query_struct_variants() {
#[derive(WorldQuery)]
pub struct NamedQuery {
id: Entity,
a: &'static A,
}

#[derive(WorldQuery)]
pub struct TupleQuery(&'static A, &'static B);

fn my_system(_: Query<(NamedQuery, TupleQuery)>) {}
assert_is_system(my_system);
}
}