Skip to content

Commit

Permalink
Extend the WorldQuery macro to tuple structs (#8119)
Browse files Browse the repository at this point in the history
# Objective

The `#[derive(WorldQuery)]` macro currently only supports structs with
named fields.

Same motivation as #6957. Remove sharp edges from the derive macro, make
it just work more often.

## Solution

Support tuple structs.

---

## Changelog

+ Added support for tuple structs to the `#[derive(WorldQuery)]` macro.
  • Loading branch information
JoJoJet authored Apr 4, 2023
1 parent 2aaaed7 commit b423e6e
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 37 deletions.
105 changes: 71 additions & 34 deletions crates/bevy_ecs/macros/src/fetch.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use bevy_macro_utils::ensure_no_collision;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::{quote, ToTokens};
use quote::{format_ident, quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input, parse_quote,
punctuated::Punctuated,
Attribute, Data, DataStruct, DeriveInput, Field, Fields,
Attribute, Data, DataStruct, DeriveInput, Field, Index,
};

use crate::bevy_ecs_path;
Expand Down Expand Up @@ -116,34 +116,49 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
fetch_struct_name.clone()
};

let marker_name =
ensure_no_collision(format_ident!("_world_query_derive_marker"), tokens.clone());

// Generate a name for the state struct that doesn't conflict
// with the struct definition.
let state_struct_name = Ident::new(&format!("{struct_name}State"), Span::call_site());
let state_struct_name = ensure_no_collision(state_struct_name, tokens);

let fields = match &ast.data {
Data::Struct(DataStruct {
fields: Fields::Named(fields),
..
}) => &fields.named,
_ => panic!("Expected a struct with named fields"),
let Data::Struct(DataStruct { fields, .. }) = &ast.data else {
return syn::Error::new(
Span::call_site(),
"#[derive(WorldQuery)]` only supports structs",
)
.into_compile_error()
.into()
};

let mut field_attrs = Vec::new();
let mut field_visibilities = Vec::new();
let mut field_idents = Vec::new();
let mut named_field_idents = Vec::new();
let mut field_types = Vec::new();
let mut read_only_field_types = Vec::new();

for field in fields {
for (i, field) in fields.iter().enumerate() {
let attrs = match read_world_query_field_info(field) {
Ok(WorldQueryFieldInfo { attrs }) => attrs,
Err(e) => return e.into_compile_error().into(),
};

let named_field_ident = field
.ident
.as_ref()
.cloned()
.unwrap_or_else(|| format_ident!("f{i}"));
let i = Index::from(i);
let field_ident = field
.ident
.as_ref()
.map_or(quote! { #i }, |i| quote! { #i });
field_idents.push(field_ident);
named_field_idents.push(named_field_ident);
field_attrs.push(attrs);
field_visibilities.push(field.vis.clone());
field_idents.push(field.ident.as_ref().unwrap().clone());
let field_ty = field.ty.clone();
field_types.push(quote!(#field_ty));
read_only_field_types.push(quote!(<#field_ty as #path::query::WorldQuery>::ReadOnly));
Expand Down Expand Up @@ -176,15 +191,34 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
&field_types
};

let item_struct = quote! {
#derive_macro_call
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility struct #item_struct_name #user_impl_generics_with_world #user_where_clauses_with_world {
#(#(#field_attrs)* #field_visibilities #field_idents: <#field_types as #path::query::WorldQuery>::Item<'__w>,)*
}
let item_struct = match fields {
syn::Fields::Named(_) => quote! {
#derive_macro_call
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility struct #item_struct_name #user_impl_generics_with_world #user_where_clauses_with_world {
#(#(#field_attrs)* #field_visibilities #field_idents: <#field_types as #path::query::WorldQuery>::Item<'__w>,)*
}
},
syn::Fields::Unnamed(_) => quote! {
#derive_macro_call
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility struct #item_struct_name #user_impl_generics_with_world #user_where_clauses_with_world(
#( #field_visibilities <#field_types as #path::query::WorldQuery>::Item<'__w>, )*
);
},
syn::Fields::Unit => quote! {
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility type #item_struct_name #user_ty_generics_with_world = #struct_name #user_ty_generics;
},
};

let query_impl = quote! {
Expand All @@ -194,7 +228,8 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
#[doc = "`], used to define the world data accessed by this query."]
#[automatically_derived]
#visibility struct #fetch_struct_name #user_impl_generics_with_world #user_where_clauses_with_world {
#(#field_idents: <#field_types as #path::query::WorldQuery>::Fetch<'__w>,)*
#(#named_field_idents: <#field_types as #path::query::WorldQuery>::Fetch<'__w>,)*
#marker_name: &'__w (),
}

// SAFETY: `update_component_access` and `update_archetype_component_access` are called on every field
Expand Down Expand Up @@ -223,14 +258,15 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
_this_run: #path::component::Tick,
) -> <Self as #path::query::WorldQuery>::Fetch<'__w> {
#fetch_struct_name {
#(#field_idents:
#(#named_field_idents:
<#field_types>::init_fetch(
_world,
&state.#field_idents,
&state.#named_field_idents,
_last_run,
_this_run,
),
)*
#marker_name: &(),
}
}

Expand All @@ -239,8 +275,9 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
) -> <Self as #path::query::WorldQuery>::Fetch<'__w> {
#fetch_struct_name {
#(
#field_idents: <#field_types>::clone_fetch(& _fetch. #field_idents),
#named_field_idents: <#field_types>::clone_fetch(& _fetch. #named_field_idents),
)*
#marker_name: &(),
}
}

Expand All @@ -256,7 +293,7 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
_archetype: &'__w #path::archetype::Archetype,
_table: &'__w #path::storage::Table
) {
#(<#field_types>::set_archetype(&mut _fetch.#field_idents, &_state.#field_idents, _archetype, _table);)*
#(<#field_types>::set_archetype(&mut _fetch.#named_field_idents, &_state.#named_field_idents, _archetype, _table);)*
}

/// SAFETY: we call `set_table` for each member that implements `Fetch`
Expand All @@ -266,7 +303,7 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
_state: &Self::State,
_table: &'__w #path::storage::Table
) {
#(<#field_types>::set_table(&mut _fetch.#field_idents, &_state.#field_idents, _table);)*
#(<#field_types>::set_table(&mut _fetch.#named_field_idents, &_state.#named_field_idents, _table);)*
}

/// SAFETY: we call `fetch` for each member that implements `Fetch`.
Expand All @@ -277,7 +314,7 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
_table_row: #path::storage::TableRow,
) -> <Self as #path::query::WorldQuery>::Item<'__w> {
Self::Item {
#(#field_idents: <#field_types>::fetch(&mut _fetch.#field_idents, _entity, _table_row),)*
#(#field_idents: <#field_types>::fetch(&mut _fetch.#named_field_idents, _entity, _table_row),)*
}
}

Expand All @@ -288,11 +325,11 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
_entity: #path::entity::Entity,
_table_row: #path::storage::TableRow,
) -> bool {
true #(&& <#field_types>::filter_fetch(&mut _fetch.#field_idents, _entity, _table_row))*
true #(&& <#field_types>::filter_fetch(&mut _fetch.#named_field_idents, _entity, _table_row))*
}

fn update_component_access(state: &Self::State, _access: &mut #path::query::FilteredAccess<#path::component::ComponentId>) {
#( <#field_types>::update_component_access(&state.#field_idents, _access); )*
#( <#field_types>::update_component_access(&state.#named_field_idents, _access); )*
}

fn update_archetype_component_access(
Expand All @@ -301,18 +338,18 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
_access: &mut #path::query::Access<#path::archetype::ArchetypeComponentId>
) {
#(
<#field_types>::update_archetype_component_access(&state.#field_idents, _archetype, _access);
<#field_types>::update_archetype_component_access(&state.#named_field_idents, _archetype, _access);
)*
}

fn init_state(world: &mut #path::world::World) -> #state_struct_name #user_ty_generics {
#state_struct_name {
#(#field_idents: <#field_types>::init_state(world),)*
#(#named_field_idents: <#field_types>::init_state(world),)*
}
}

fn matches_component_set(state: &Self::State, _set_contains_id: &impl Fn(#path::component::ComponentId) -> bool) -> bool {
true #(&& <#field_types>::matches_component_set(&state.#field_idents, _set_contains_id))*
true #(&& <#field_types>::matches_component_set(&state.#named_field_idents, _set_contains_id))*
}
}
};
Expand All @@ -328,7 +365,7 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
#[doc = "`]."]
#[automatically_derived]
#visibility struct #read_only_struct_name #user_impl_generics #user_where_clauses {
#( #field_visibilities #field_idents: #read_only_field_types, )*
#( #field_visibilities #named_field_idents: #read_only_field_types, )*
}

#readonly_state
Expand Down Expand Up @@ -374,7 +411,7 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
#[doc = "`], used for caching."]
#[automatically_derived]
#visibility struct #state_struct_name #user_impl_generics #user_where_clauses {
#(#field_idents: <#field_types as #path::query::WorldQuery>::State,)*
#(#named_field_idents: <#field_types as #path::query::WorldQuery>::State,)*
}

#mutable_impl
Expand Down
31 changes: 28 additions & 3 deletions crates/bevy_ecs/src/query/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ use std::{cell::UnsafeCell, marker::PhantomData};
/// - Methods can be implemented for the query items.
/// - There is no hardcoded limit on the number of elements.
///
/// This trait can only be derived if each field also implements `WorldQuery`.
/// The derive macro only supports regular structs (structs with named fields).
/// This trait can only be derived for structs, if each field also implements `WorldQuery`.
///
/// ```
/// # use bevy_ecs::prelude::*;
Expand Down Expand Up @@ -1468,11 +1467,37 @@ unsafe impl<T: ?Sized> ReadOnlyWorldQuery for PhantomData<T> {}
#[cfg(test)]
mod tests {
use super::*;
use crate::{self as bevy_ecs, system::Query};
use crate::{
self as bevy_ecs,
system::{assert_is_system, Query},
};

#[derive(Component)]
pub struct A;

#[derive(Component)]
pub struct B;

// Tests that each variant of struct can be used as a `WorldQuery`.
#[test]
fn world_query_struct_variants() {
#[derive(WorldQuery)]
pub struct NamedQuery {
id: Entity,
a: &'static A,
}

#[derive(WorldQuery)]
pub struct TupleQuery(&'static A, &'static B);

#[derive(WorldQuery)]
pub struct UnitQuery;

fn my_system(_: Query<(NamedQuery, TupleQuery, UnitQuery)>) {}

assert_is_system(my_system);
}

// Compile test for https://github.com/bevyengine/bevy/pull/8030.
#[test]
fn world_query_phantom_data() {
Expand Down

0 comments on commit b423e6e

Please sign in to comment.