From 02312eff4c57becc2f4ffef4707dc72dc85e8e7f Mon Sep 17 00:00:00 2001 From: mvlabat Date: Mon, 23 Aug 2021 00:37:22 +0300 Subject: [PATCH] Implement Fetch and FetchFilter derive macros --- Cargo.toml | 4 + crates/bevy_ecs/macros/src/lib.rs | 724 +++++++++++++++++++++++++++- crates/bevy_ecs/src/query/fetch.rs | 90 ++++ crates/bevy_ecs/src/query/filter.rs | 26 + examples/README.md | 1 + examples/ecs/fetch.rs | 93 ++++ 6 files changed, 936 insertions(+), 2 deletions(-) create mode 100644 examples/ecs/fetch.rs diff --git a/Cargo.toml b/Cargo.toml index 0fafd91be79705..d9ac54c8de80b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -270,6 +270,10 @@ path = "examples/ecs/component_change_detection.rs" name = "event" path = "examples/ecs/event.rs" +[[example]] +name = "fetch" +path = "examples/ecs/fetch.rs" + [[example]] name = "fixed_timestep" path = "examples/ecs/fixed_timestep.rs" diff --git a/crates/bevy_ecs/macros/src/lib.rs b/crates/bevy_ecs/macros/src/lib.rs index 8371de3e4b1605..9d88707e0eff3d 100644 --- a/crates/bevy_ecs/macros/src/lib.rs +++ b/crates/bevy_ecs/macros/src/lib.rs @@ -9,8 +9,9 @@ use syn::{ parse_macro_input, punctuated::Punctuated, token::Comma, - Data, DataStruct, DeriveInput, Field, Fields, GenericParam, Ident, Index, LitInt, Path, Result, - Token, + Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, GenericParam, Ident, + ImplGenerics, Index, Lifetime, LifetimeDef, LitInt, Path, PathArguments, Result, Token, Type, + TypeGenerics, TypePath, TypeReference, WhereClause, }; struct AllTuples { @@ -423,6 +424,411 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream { }) } +static READONLY_ATTRIBUTE_NAME: &str = "readonly"; +static FILTER_ATTRIBUTE_NAME: &str = "filter"; + +/// Implement `WorldQuery` to use a struct as a parameter in a query +#[proc_macro_derive(Fetch, attributes(readonly, filter))] +pub fn derive_fetch(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + + let FetchImplTokens { + struct_name, + fetch_struct_name, + state_struct_name, + fetch_trait_punctuated_lifetimes, + impl_generics, + ty_generics, + where_clause, + struct_has_world_lt, + world_lt, + state_lt, + } = fetch_impl_tokens(&ast); + + // Fetch's HRTBs require this hack to make the implementation compile. I don't fully understand + // why this works though. If anyone's curious enough to try to find a better work-around, I'll + // leave playground links here: + // - https://play.rust-lang.org/?version=stable&mode=debug&edition=2018&gist=da5e260a5c2f3e774142d60a199e854a (this fails) + // - https://play.rust-lang.org/?version=stable&mode=debug&edition=2018&gist=802517bb3d8f83c45ee8c0be360bb250 (this compiles) + let mut fetch_generics = ast.generics.clone(); + fetch_generics.params.insert(0, state_lt); + if !struct_has_world_lt { + fetch_generics.params.insert(0, world_lt); + } + fetch_generics + .params + .push(GenericParam::Lifetime(LifetimeDef::new(Lifetime::new( + "'fetch", + Span::call_site(), + )))); + let (fetch_impl_generics, _, _) = fetch_generics.split_for_impl(); + let mut fetch_generics = ast.generics.clone(); + if struct_has_world_lt { + *fetch_generics.params.first_mut().unwrap() = + GenericParam::Lifetime(LifetimeDef::new(Lifetime::new("'fetch", Span::call_site()))); + } + let (_, fetch_ty_generics, _) = fetch_generics.split_for_impl(); + + let path = bevy_ecs_path(); + + let fields = match &ast.data { + Data::Struct(DataStruct { + fields: Fields::Named(fields), + .. + }) => &fields.named, + _ => panic!("Expected a struct with named fields"), + }; + + let mut is_read_only = true; + + let mut phantom_field_idents = Vec::new(); + let mut phantom_field_types = Vec::new(); + let mut field_idents = Vec::new(); + let mut filter_field_idents = Vec::new(); + let mut non_filter_field_idents = Vec::new(); + let mut query_types = Vec::new(); + let mut fetch_init_types = Vec::new(); + let mut readonly_types_to_assert = Vec::new(); + + let generic_names = ast + .generics + .params + .iter() + .filter_map(|param| match param { + GenericParam::Type(ty) => Some(ty.ident.to_string()), + _ => None, + }) + .collect::>(); + + for field in fields.iter() { + let has_readonly_attribute = field.attrs.iter().any(|attr| { + attr.path + .get_ident() + .map_or(false, |ident| ident == READONLY_ATTRIBUTE_NAME) + }); + let filter_type = field + .attrs + .iter() + .find(|attr| { + attr.path + .get_ident() + .map_or(false, |ident| ident == FILTER_ATTRIBUTE_NAME) + }) + .map(|filter| { + filter + .parse_args::() + .expect("Expected a filter type (example: `#[filter(With)]`)") + }); + let is_filter = filter_type.is_some(); + + let WorldQueryFieldTypeInfo { + query_type, + fetch_init_type: init_type, + is_read_only: field_is_read_only, + is_phantom, + readonly_types_to_assert: field_readonly_types_to_assert, + } = read_world_query_field_type_info( + &field.ty, + false, + filter_type, + has_readonly_attribute, + &generic_names, + ); + + let field_ident = field.ident.as_ref().unwrap().clone(); + if is_phantom { + phantom_field_idents.push(field_ident.clone()); + phantom_field_types.push(field.ty.clone()); + } else if is_filter { + field_idents.push(field_ident.clone()); + filter_field_idents.push(field_ident.clone()); + query_types.push(query_type); + fetch_init_types.push(init_type); + } else { + field_idents.push(field_ident.clone()); + non_filter_field_idents.push(field_ident.clone()); + query_types.push(query_type); + fetch_init_types.push(init_type); + } + is_read_only = is_read_only && field_is_read_only; + readonly_types_to_assert.extend(field_readonly_types_to_assert.into_iter()); + } + + let read_only_impl = if is_read_only { + quote! { + /// SAFETY: each item in the struct is read only + unsafe impl #impl_generics #path::query::ReadOnlyFetch for #fetch_struct_name #ty_generics #where_clause {} + + // Statically checks that the safety guarantee holds true indeed. We need this to make + // sure that we don't compile ReadOnlyFetch if our struct contains nested WorldQuery + // that don't implement it. + #[allow(dead_code)] + const _: () = { + fn assert_readonly() {} + + // We generate a readonly assertion for every type that isn't &T, &mut T, Option<&T> or Option<&mut T> + fn assert_all #impl_generics () #where_clause { + #(assert_readonly::<<#readonly_types_to_assert as #path::query::WorldQuery>::Fetch>();)* + } + }; + } + } else { + quote! {} + }; + + let tokens = TokenStream::from(quote! { + struct #fetch_struct_name #impl_generics #where_clause { + #(#field_idents: <#query_types as #path::query::WorldQuery>::Fetch,)* + } + + struct #state_struct_name #impl_generics #where_clause { + #(#field_idents: <#query_types as #path::query::WorldQuery>::State,)* + } + + impl #fetch_impl_generics #path::query::Fetch<#fetch_trait_punctuated_lifetimes> for #fetch_struct_name #fetch_ty_generics #where_clause { + type Item = #struct_name #ty_generics; + type State = #state_struct_name #fetch_ty_generics; + + unsafe fn init(_world: &#path::world::World, state: &Self::State, _last_change_tick: u32, _change_tick: u32) -> Self { + #fetch_struct_name { + #(#field_idents: <#fetch_init_types as #path::query::WorldQuery>::Fetch::init(_world, &state.#field_idents, _last_change_tick, _change_tick),)* + } + } + + #[inline] + fn is_dense(&self) -> bool { + true #(&& self.#field_idents.is_dense())* + } + + #[inline] + unsafe fn set_archetype(&mut self, _state: &Self::State, _archetype: &#path::archetype::Archetype, _tables: &#path::storage::Tables) { + #(self.#field_idents.set_archetype(&_state.#field_idents, _archetype, _tables);)* + } + + #[inline] + unsafe fn set_table(&mut self, _state: &Self::State, _table: &#path::storage::Table) { + #(self.#field_idents.set_table(&_state.#field_idents, _table);)* + } + + #[inline] + unsafe fn table_fetch(&mut self, _table_row: usize) -> Self::Item { + use #path::query::FilterFetch; + #struct_name { + #(#non_filter_field_idents: self.#non_filter_field_idents.table_fetch(_table_row),)* + #(#filter_field_idents: self.#filter_field_idents.table_filter_fetch(_table_row),)* + #(#phantom_field_idents: Default::default(),)* + } + } + + #[inline] + unsafe fn archetype_fetch(&mut self, _archetype_index: usize) -> Self::Item { + use #path::query::FilterFetch; + #struct_name { + #(#non_filter_field_idents: self.#non_filter_field_idents.archetype_fetch(_archetype_index),)* + #(#filter_field_idents: self.#filter_field_idents.archetype_filter_fetch(_archetype_index),)* + #(#phantom_field_idents: Default::default(),)* + } + } + } + + // SAFETY: update_component_access and update_archetype_component_access are called for each item in the struct + unsafe impl #impl_generics #path::query::FetchState for #state_struct_name #ty_generics #where_clause { + fn init(world: &mut #path::world::World) -> Self { + #state_struct_name { + #(#field_idents: <#query_types as #path::query::WorldQuery>::State::init(world),)* + } + } + + fn update_component_access(&self, _access: &mut #path::query::FilteredAccess<#path::component::ComponentId>) { + #(self.#field_idents.update_component_access(_access);)* + } + + fn update_archetype_component_access(&self, _archetype: &#path::archetype::Archetype, _access: &mut #path::query::Access<#path::archetype::ArchetypeComponentId>) { + #(self.#field_idents.update_archetype_component_access(_archetype, _access);)* + } + + fn matches_archetype(&self, _archetype: &#path::archetype::Archetype) -> bool { + true #(&& self.#field_idents.matches_archetype(_archetype))* + } + + fn matches_table(&self, _table: &#path::storage::Table) -> bool { + true #(&& self.#field_idents.matches_table(_table))* + } + } + + impl #impl_generics #path::query::WorldQuery for #struct_name #ty_generics #where_clause { + type Fetch = #fetch_struct_name #ty_generics; + type State = #state_struct_name #ty_generics; + } + + #read_only_impl + }); + tokens +} + +/// Implement `FilterFetch` to use a struct as a filter parameter in a query +#[proc_macro_derive(FilterFetch)] +pub fn derive_fetch_filter(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + + let FetchImplTokens { + struct_name, + fetch_struct_name, + state_struct_name, + fetch_trait_punctuated_lifetimes, + impl_generics, + ty_generics, + where_clause, + struct_has_world_lt, + world_lt, + state_lt, + } = fetch_impl_tokens(&ast); + + // Fetch's HRTBs require this hack to make the implementation compile. I don't fully understand + // why this works though. If anyone's curious enough to try to find a better work-around, I'll + // leave playground links here: + // - https://play.rust-lang.org/?version=stable&mode=debug&edition=2018&gist=da5e260a5c2f3e774142d60a199e854a (this fails) + // - https://play.rust-lang.org/?version=stable&mode=debug&edition=2018&gist=802517bb3d8f83c45ee8c0be360bb250 (this compiles) + let mut fetch_generics = ast.generics.clone(); + fetch_generics.params.insert(0, state_lt); + if !struct_has_world_lt { + fetch_generics.params.insert(0, world_lt); + } + fetch_generics + .params + .push(GenericParam::Lifetime(LifetimeDef::new(Lifetime::new( + "'fetch", + Span::call_site(), + )))); + let (fetch_impl_generics, _, _) = fetch_generics.split_for_impl(); + let mut fetch_generics = ast.generics.clone(); + if struct_has_world_lt { + *fetch_generics.params.first_mut().unwrap() = + GenericParam::Lifetime(LifetimeDef::new(Lifetime::new("'fetch", Span::call_site()))); + } + let (_, fetch_ty_generics, _) = fetch_generics.split_for_impl(); + + let path = bevy_ecs_path(); + + let fields = match &ast.data { + Data::Struct(DataStruct { + fields: Fields::Named(fields), + .. + }) => &fields.named, + _ => panic!("Expected a struct with named fields"), + }; + + let mut phantom_field_idents = Vec::new(); + let mut phantom_field_types = Vec::new(); + let mut field_idents = Vec::new(); + let mut field_types = Vec::new(); + + for field in fields.iter() { + let is_phantom = match &field.ty { + Type::Path(ty_path) => { + let last_segment = ty_path.path.segments.last().unwrap(); + last_segment.ident == "PhantomData" + } + _ => false, + }; + + let field_ident = field.ident.as_ref().unwrap().clone(); + if is_phantom { + phantom_field_idents.push(field_ident.clone()); + phantom_field_types.push(field.ty.clone()); + } else { + field_idents.push(field_ident.clone()); + field_types.push(field.ty.clone()); + } + } + + let tokens = TokenStream::from(quote! { + struct #fetch_struct_name #impl_generics #where_clause { + #(#field_idents: <#field_types as #path::query::WorldQuery>::Fetch,)* + #(#phantom_field_idents: #phantom_field_types,)* + } + + struct #state_struct_name #impl_generics #where_clause { + #(#field_idents: <#field_types as #path::query::WorldQuery>::State,)* + #(#phantom_field_idents: #phantom_field_types,)* + } + + impl #fetch_impl_generics #path::query::Fetch<#fetch_trait_punctuated_lifetimes> for #fetch_struct_name #fetch_ty_generics #where_clause { + type Item = bool; + type State = #state_struct_name #fetch_ty_generics; + + unsafe fn init(_world: &#path::world::World, state: &Self::State, _last_change_tick: u32, _change_tick: u32) -> Self { + #fetch_struct_name { + #(#field_idents: <#field_types as #path::query::WorldQuery>::Fetch::init(_world, &state.#field_idents, _last_change_tick, _change_tick),)* + #(#phantom_field_idents: Default::default(),)* + } + } + + #[inline] + fn is_dense(&self) -> bool { + true #(&& self.#field_idents.is_dense())* + } + + #[inline] + unsafe fn set_archetype(&mut self, _state: &Self::State, _archetype: &#path::archetype::Archetype, _tables: &#path::storage::Tables) { + #(self.#field_idents.set_archetype(&_state.#field_idents, _archetype, _tables);)* + } + + #[inline] + unsafe fn set_table(&mut self, _state: &Self::State, _table: &#path::storage::Table) { + #(self.#field_idents.set_table(&_state.#field_idents, _table);)* + } + + #[inline] + unsafe fn table_fetch(&mut self, _table_row: usize) -> Self::Item { + use #path::query::FilterFetch; + true #(&& self.#field_idents.table_filter_fetch(_table_row))* + } + + #[inline] + unsafe fn archetype_fetch(&mut self, _archetype_index: usize) -> Self::Item { + use #path::query::FilterFetch; + true #(&& self.#field_idents.archetype_filter_fetch(_archetype_index))* + } + } + + // SAFETY: update_component_access and update_archetype_component_access are called for each item in the struct + unsafe impl #impl_generics #path::query::FetchState for #state_struct_name #ty_generics #where_clause { + fn init(world: &mut #path::world::World) -> Self { + #state_struct_name { + #(#field_idents: <#field_types as #path::query::WorldQuery>::State::init(world),)* + #(#phantom_field_idents: Default::default(),)* + } + } + + fn update_component_access(&self, _access: &mut #path::query::FilteredAccess<#path::component::ComponentId>) { + #(self.#field_idents.update_component_access(_access);)* + } + + fn update_archetype_component_access(&self, _archetype: &#path::archetype::Archetype, _access: &mut #path::query::Access<#path::archetype::ArchetypeComponentId>) { + #(self.#field_idents.update_archetype_component_access(_archetype, _access);)* + } + + fn matches_archetype(&self, _archetype: &#path::archetype::Archetype) -> bool { + true #(&& self.#field_idents.matches_archetype(_archetype))* + } + + fn matches_table(&self, _table: &#path::storage::Table) -> bool { + true #(&& self.#field_idents.matches_table(_table))* + } + } + + impl #impl_generics #path::query::WorldQuery for #struct_name #ty_generics #where_clause { + type Fetch = #fetch_struct_name #ty_generics; + type State = #state_struct_name #ty_generics; + } + + /// SAFETY: each item in the struct is a fetch filter and thus is read only + unsafe impl #impl_generics #path::query::ReadOnlyFetch for #fetch_struct_name #ty_generics #where_clause {} + }); + tokens +} + #[proc_macro_derive(SystemLabel)] pub fn derive_system_label(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -468,6 +874,320 @@ fn derive_label(input: DeriveInput, label_type: Ident) -> TokenStream2 { } } +struct FetchImplTokens<'a> { + struct_name: Ident, + fetch_struct_name: Ident, + state_struct_name: Ident, + fetch_trait_punctuated_lifetimes: Punctuated, + impl_generics: ImplGenerics<'a>, + ty_generics: TypeGenerics<'a>, + where_clause: Option<&'a WhereClause>, + struct_has_world_lt: bool, + world_lt: GenericParam, + state_lt: GenericParam, +} + +fn fetch_impl_tokens(ast: &DeriveInput) -> FetchImplTokens { + let world_lt = ast.generics.params.first().and_then(|param| match param { + lt @ GenericParam::Lifetime(_) => Some(lt.clone()), + _ => None, + }); + let struct_has_world_lt = world_lt.is_some(); + let world_lt = world_lt.unwrap_or_else(|| { + GenericParam::Lifetime(LifetimeDef::new(Lifetime::new("'world", Span::call_site()))) + }); + let state_lt = + GenericParam::Lifetime(LifetimeDef::new(Lifetime::new("'state", Span::call_site()))); + + let mut fetch_trait_punctuated_lifetimes = Punctuated::<_, Token![,]>::new(); + fetch_trait_punctuated_lifetimes.push(world_lt.clone()); + fetch_trait_punctuated_lifetimes.push(state_lt.clone()); + + let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); + + let struct_name = ast.ident.clone(); + let fetch_struct_name = Ident::new(&format!("{}Fetch", struct_name), Span::call_site()); + let state_struct_name = Ident::new(&format!("{}State", struct_name), Span::call_site()); + + FetchImplTokens { + struct_name, + fetch_struct_name, + state_struct_name, + fetch_trait_punctuated_lifetimes, + impl_generics, + ty_generics, + where_clause, + struct_has_world_lt, + world_lt, + state_lt, + } +} + +struct WorldQueryFieldTypeInfo { + /// We convert `Mut` to `&mut T` (because this is the type that implements `WorldQuery`) + /// and store it here. + query_type: Type, + /// The same as `query_type` but with `'fetch` lifetime. + fetch_init_type: Type, + is_read_only: bool, + is_phantom: bool, + readonly_types_to_assert: Vec, +} + +fn read_world_query_field_type_info( + ty: &Type, + is_tuple_element: bool, + filter_type: Option, + has_readonly_attribute: bool, + generic_names: &[String], +) -> WorldQueryFieldTypeInfo { + let mut query_type = ty.clone(); + let mut fetch_init_type = ty.clone(); + let mut is_read_only = true; + let mut is_phantom = false; + let mut readonly_types_to_assert = Vec::new(); + + match (ty, &mut fetch_init_type) { + (Type::Path(path), Type::Path(path_init)) => { + if path.qself.is_some() { + // There's a risk that it contains a generic parameter that we can't test + // whether it's readonly or not. + panic!("Self type qualifiers aren't supported"); + } + + let segment = path.path.segments.last().unwrap(); + if segment.ident == "Option" { + // We expect that `Option` stores either `&T` or `Mut`. + let ty = match &segment.arguments { + PathArguments::AngleBracketed(args) => { + args.args.last().and_then(|arg| match arg { + GenericArgument::Type(ty) => Some(ty), + _ => None, + }) + } + _ => None, + }; + match ty.expect("Option type is expected to have generic arguments") { + // If it's a read-only reference, we just update the lifetime for `fetch_init_type` to `'fetch`. + Type::Reference(reference) => { + if reference.mutability.is_some() { + panic!("Invalid reference type: use `Mut` instead of `&mut T`"); + } + match &mut path_init.path.segments.last_mut().unwrap().arguments { + PathArguments::AngleBracketed(args) => { + match args.args.last_mut().unwrap() { + GenericArgument::Type(Type::Reference(ty)) => ty.lifetime = Some(Lifetime::new("'fetch", Span::call_site())), + _ => unreachable!(), + } + } + _ => unreachable!(), + } + } + // If it's a mutable reference, we set `query_type` and `fetch_init_type` to `&mut T`, + // we also update the lifetime for `fetch_init_type` to `'fetch`. + Type::Path(path) => { + assert_not_generic(&path, generic_names); + + let segment = path.path.segments.last().unwrap(); + let ty_ident = &segment.ident; + if ty_ident == "Mut" { + is_read_only = false; + let (mut_lifetime, mut_ty) = match &segment.arguments { + PathArguments::AngleBracketed(args) => { + (args.args.first().and_then(|arg| { match arg { + GenericArgument::Lifetime(lifetime) => Some(lifetime.clone()), + _ => None, + }}).expect("Mut is expected to have a lifetime"), + args.args.last().and_then(|arg| { match arg { + GenericArgument::Type(ty) => Some(ty.clone()), + _ => None, + }}).expect("Mut is expected to have a lifetime")) + } + _ => panic!("Mut type is expected to have generic arguments") + }; + + match query_type { + Type::Path(ref mut path) => { + let segment = path.path.segments.last_mut().unwrap(); + match segment.arguments { + PathArguments::AngleBracketed(ref mut args) => { + match args.args.last_mut().unwrap() { + GenericArgument::Type(ty) => { + *ty = Type::Reference(TypeReference { + and_token: Token![&](Span::call_site()), + lifetime: Some(mut_lifetime), + mutability: Some(Token![mut](Span::call_site())), + elem: Box::new(mut_ty.clone()), + }); + } + _ => unreachable!() + } + } + _ => unreachable!() + } + } + _ => unreachable!() + } + + let segment = path_init.path.segments.last_mut().unwrap(); + match segment.arguments { + PathArguments::AngleBracketed(ref mut args) => { + match args.args.last_mut().unwrap() { + GenericArgument::Type(ty) => { + *ty = Type::Reference(TypeReference { + and_token: Token![&](Span::call_site()), + lifetime: Some(Lifetime::new("'fetch", Span::call_site())), + mutability: Some(Token![mut](Span::call_site())), + elem: Box::new(mut_ty), + }); + } + _ => unreachable!() + } + } + _ => unreachable!() + } + } else { + panic!("Option type is expected to have a reference value (`Option<&T>` or `Option>`)"); + } + } + _ => panic!("Option type is expected to have a reference value (`Option<&T>` or `Option>`)"), + } + } else if segment.ident == "Mut" { + is_read_only = false; + // If it's a mutable reference, we set `query_type` and `fetch_init_type` to `&mut T`, + // we also update the lifetime for `fetch_init_type` to `'fetch`. + let (mut_lifetime, mut_ty) = match &segment.arguments { + PathArguments::AngleBracketed(args) => { + let lt = args.args.first().and_then(|arg| { match arg { + GenericArgument::Lifetime(lifetime) => Some(lifetime.clone()), + _ => None, + }}).expect("`Mut` is expected to have a lifetime"); + let ty = args.args.last().and_then(|arg| { match arg { + GenericArgument::Type(ty) => Some(ty.clone()), + _ => None, + }}).expect("`Mut` is expected to have a lifetime"); + (lt, ty) + } + _ => panic!("`Mut` is expected to have generic arguments") + }; + + query_type = Type::Reference(TypeReference { + and_token: Token![&](Span::call_site()), + lifetime: Some(mut_lifetime), + mutability: Some(Token![mut](Span::call_site())), + elem: Box::new(mut_ty.clone()), + }); + fetch_init_type = Type::Reference(TypeReference { + and_token: Token![&](Span::call_site()), + lifetime: Some(Lifetime::new("'fetch", Span::call_site())), + mutability: Some(Token![mut](Span::call_site())), + elem: Box::new(mut_ty), + }); + } else if segment.ident == "bool" { + if is_tuple_element { + panic!("Invalid tuple element: bool"); + } + fetch_init_type = filter_type.expect("Field type is `bool` but no `filter` attribute is found (example: `#[filter(With)]`)"); + query_type = fetch_init_type.clone(); + } else if segment.ident == "With" || segment.ident == "Without" || segment.ident == "Or" || segment.ident == "Added" || segment.ident == "Changed" { + panic!("Invalid filter type: use `bool` field type and specify the filter with `#[filter({})]` attribute", segment.ident.to_string()); + } else if segment.ident == "PhantomData" { + if is_tuple_element { + panic!("Invalid tuple element: PhantomData"); + } + is_phantom = true; + } else if segment.ident != "Entity" { + assert_not_generic(&path, generic_names); + + match &mut path_init.path.segments.last_mut().unwrap().arguments { + PathArguments::AngleBracketed(args) => { + match args.args.first_mut() { + Some(GenericArgument::Lifetime(lt)) => { + *lt = Lifetime::new("'fetch", Span::call_site()); + } + _ => {}, + } + } + _ => {}, + } + + // If there's no `filter` attribute, we assume that it's a nested struct that implements `Fetch`. + if filter_type.is_none() { + // If a user marks the field with the `readonly` attribute, we'll insert + // a function call (no calls will happen in runtime), that will check that + // the type implements `ReadOnlyFetch` indeed. + // We can't allow ourselves to implement `ReadOnlyFetch` for the current struct + // if we are not sure that all members implement it. + if has_readonly_attribute { + readonly_types_to_assert.push(path.clone()); + } else { + is_read_only = false; + } + } + } + } + (Type::Reference(reference), Type::Reference(init_reference)) => { + if reference.mutability.is_some() { + panic!("Invalid reference type: use `Mut` instead of `&mut T`"); + } + init_reference.lifetime = Some(Lifetime::new("'fetch", Span::call_site())); + } + (Type::Tuple(tuple), Type::Tuple(init_tuple)) => { + let mut query_tuple_elems = tuple.elems.clone(); + query_tuple_elems.clear(); + let mut fetch_init_tuple_elems = query_tuple_elems.clone(); + for ty in tuple.elems.iter() { + let WorldQueryFieldTypeInfo { + query_type, + fetch_init_type, + is_read_only: elem_is_read_only, + is_phantom: _, + readonly_types_to_assert: elem_readonly_types_to_assert, + } = read_world_query_field_type_info( + ty, + true, + None, + has_readonly_attribute, + generic_names, + ); + query_tuple_elems.push(query_type); + fetch_init_tuple_elems.push(fetch_init_type); + is_read_only = is_read_only && elem_is_read_only; + readonly_types_to_assert.extend(elem_readonly_types_to_assert.into_iter()); + } + match query_type { + Type::Tuple(ref mut tuple) => { + tuple.elems = query_tuple_elems; + } + _ => unreachable!(), + } + init_tuple.elems = fetch_init_tuple_elems; + } + _ => panic!("Only the following types (or their tuples) are supported for WorldQuery: &T, &mut T, Option<&T>, Option<&mut T>, Entity, or other structs that implement WorldQuery"), + } + + return WorldQueryFieldTypeInfo { + query_type, + fetch_init_type, + is_read_only, + is_phantom, + readonly_types_to_assert, + }; +} + +fn assert_not_generic(type_path: &TypePath, generic_names: &[String]) { + // `get_ident` returns Some if it consists of a single segment, in this case it + // makes sense to ensure that it's not a generic. + if let Some(ident) = type_path.path.get_ident() { + let is_generic = generic_names + .iter() + .any(|generic_name| ident == generic_name.as_str()); + if is_generic { + panic!("Only references to generic types are supported: i.e. instead of `component: T`, use `component: &T` or `component: Mut` (optional references are supported as well)"); + } + } +} + fn bevy_ecs_path() -> syn::Path { BevyManifest::default().get_path("bevy_ecs") } diff --git a/crates/bevy_ecs/src/query/fetch.rs b/crates/bevy_ecs/src/query/fetch.rs index 9e20c208761e78..dccea3eeab1f3b 100644 --- a/crates/bevy_ecs/src/query/fetch.rs +++ b/crates/bevy_ecs/src/query/fetch.rs @@ -8,6 +8,7 @@ use crate::{ world::{Mut, World}, }; use bevy_ecs_macros::all_tuples; +pub use bevy_ecs_macros::{Fetch, FilterFetch}; use std::{ cell::UnsafeCell, marker::PhantomData, @@ -45,6 +46,95 @@ pub trait WorldQuery { type State: FetchState; } +/// # Derive +/// +/// This trait can be derived with the [`derive@super::Fetch`] macro. +/// To do so, all fields in the struct must themselves impl [`WorldQuery`]. +/// +/// ``` +/// # use bevy_ecs::prelude::*; +/// use bevy_ecs::query::Fetch; +/// +/// #[derive(Fetch)] +/// struct MyQuery<'w> { +/// foo: &'w u32, +/// bar: Mut<'w, i32>, +/// } +/// +/// fn my_system(mut query: Query) { +/// for q in query.iter_mut() { +/// q.foo; +/// } +/// } +/// +/// # my_system.system(); +/// ``` +/// +/// ## Usage with filters +/// +/// All filter members must be marked with `filter` attribute and have `bool` type. +/// +/// ``` +/// # use bevy_ecs::prelude::*; +/// use bevy_ecs::query::Fetch; +/// +/// #[derive(Fetch)] +/// struct MyQuery<'w> { +/// foo: &'w u32, +/// bar: Mut<'w, i32>, +/// #[filter(Changed)] +/// foo_is_changed: bool, +/// } +/// ``` +/// +/// ## Read-only queries +/// +/// All queries that access components non-mutably are read-only by default, with the exception +/// of nested custom queries (containing members that implement [`Fetch`] with the derive macro). +/// In order to compile a nested query as a read-only one, such members must be marked with +/// the `readonly` attribute. +/// +/// ``` +/// # use bevy_ecs::prelude::*; +/// use bevy_ecs::query::{Fetch, ReadOnlyFetch, WorldQuery}; +/// +/// #[derive(Fetch)] +/// struct FooQuery<'w> { +/// foo: &'w u32, +/// #[readonly] +/// bar_query: BarQuery<'w>, +/// } +/// +/// #[derive(Fetch)] +/// struct BarQuery<'w> { +/// bar: &'w u32, +/// } +/// +/// fn assert_readonly() {} +/// +/// assert_readonly::<::Fetch>(); +/// ``` +/// +/// **Note** that if you mark a field that doesn't implement `ReadOnlyFetch` as `readonly`, the +/// compilation will fail. We insert static checks as in the example above for every nested query +/// marked as `readonly`. (They neither affect the runtime, nor pollute your local namespace.) +/// +/// ```compile_fail +/// # use bevy_ecs::prelude::*; +/// use bevy_ecs::query::{Fetch, ReadOnlyFetch, WorldQuery}; +/// +/// #[derive(Fetch)] +/// struct FooQuery<'w> { +/// foo: &'w u32, +/// #[readonly] +/// bar_query: BarQuery<'w>, +/// } +/// +/// #[derive(Fetch)] +/// struct BarQuery<'w> { +/// bar: Mut<'w, u32>, +/// } +/// ``` pub trait Fetch<'world, 'state>: Sized { type Item; type State: FetchState; diff --git a/crates/bevy_ecs/src/query/filter.rs b/crates/bevy_ecs/src/query/filter.rs index 69dd298cd31e00..981d41dcef3b4b 100644 --- a/crates/bevy_ecs/src/query/filter.rs +++ b/crates/bevy_ecs/src/query/filter.rs @@ -11,6 +11,32 @@ use std::{cell::UnsafeCell, marker::PhantomData, ptr}; /// Extension trait for [`Fetch`] containing methods used by query filters. /// This trait exists to allow "short circuit" behaviors for relevant query filter fetches. +/// +/// ## Derive +/// +/// This trait can be derived with the [`derive@super::FilterFetch`] macro. +/// To do so, all fields in the struct must be filters themselves (their [`WorldQuery::Fetch`] +/// associated types should implement [`FilterFetch`]). +/// +/// ``` +/// # use bevy_ecs::prelude::*; +/// use bevy_ecs::{query::FilterFetch, component::Component}; +/// +/// #[derive(FilterFetch)] +/// struct MyFilter { +/// _u_16: With, +/// _u_32: With, +/// _or: Or<(With, Changed, Added)>, +/// _generic_tuple: (With, Without

), +/// _tp: std::marker::PhantomData<(T, P)>, +/// } +/// +/// fn my_system(query: Query>) { +/// for _ in query.iter() {} +/// } +/// +/// # my_system.system(); +/// ``` pub trait FilterFetch: for<'w, 's> Fetch<'w, 's> { /// # Safety /// diff --git a/examples/README.md b/examples/README.md index 7796afdb162237..3a2d7cbda92b07 100644 --- a/examples/README.md +++ b/examples/README.md @@ -158,6 +158,7 @@ Example | File | Description `ecs_guide` | [`ecs/ecs_guide.rs`](./ecs/ecs_guide.rs) | Full guide to Bevy's ECS `component_change_detection` | [`ecs/component_change_detection.rs`](./ecs/component_change_detection.rs) | Change detection on components `event` | [`ecs/event.rs`](./ecs/event.rs) | Illustrates event creation, activation, and reception +`fetch` | [`ecs/fetch.rs`](./ecs/fetch.rs) | Illustrates creating custom queries and query filters with `Fetch` and `FilterFetch` `fixed_timestep` | [`ecs/fixed_timestep.rs`](./ecs/fixed_timestep.rs) | Shows how to create systems that run every fixed timestep, rather than every tick `hierarchy` | [`ecs/hierarchy.rs`](./ecs/hierarchy.rs) | Creates a hierarchy of parents and children entities `iter_combinations` | [`ecs/iter_combinations.rs`](./ecs/iter_combinations.rs) | Shows how to iterate over combinations of query results. diff --git a/examples/ecs/fetch.rs b/examples/ecs/fetch.rs new file mode 100644 index 00000000000000..438233e2bed056 --- /dev/null +++ b/examples/ecs/fetch.rs @@ -0,0 +1,93 @@ +use bevy::ecs::component::Component; +use bevy::{ + ecs::query::{Fetch, FilterFetch}, + prelude::*, +}; +use std::marker::PhantomData; + +fn main() { + App::new() + .add_startup_system(spawn) + .add_system(print_nums) + .add_system(print_nums_readonly) + .run(); +} + +#[derive(Fetch, Debug)] +struct NumQuery<'w, T: Component, P: Component> { + entity: Entity, + u: UNumQuery<'w>, + i: INumQuery<'w>, + generic: MutGenericQuery<'w, T, P>, + #[filter(NumQueryFilter)] + filter: bool, +} + +// If you want to declare a read-only query that uses nested `Fetch` structs, you need to +// specify `readonly` attribute for the corresponding fields. This will generate static assertions +// that those members implement `ReadOnlyFetch`. +#[derive(Fetch, Debug)] +struct ReadOnlyNumQuery<'w, T: Component, P: Component> { + entity: Entity, + #[readonly] + u: UNumQuery<'w>, + #[readonly] + generic: GenericQuery<'w, T, P>, + #[filter(NumQueryFilter)] + filter: bool, +} + +#[derive(Fetch, Debug)] +struct UNumQuery<'w> { + u_16: &'w u16, + u_32_opt: Option<&'w u32>, +} + +#[derive(Fetch, Debug)] +struct MutGenericQuery<'w, T: Component, P: Component> { + generic: (Mut<'w, T>, Mut<'w, P>), +} + +#[derive(Fetch, Debug)] +struct GenericQuery<'w, T: Component, P: Component> { + generic: (&'w T, &'w P), +} + +#[derive(Fetch, Debug)] +struct INumQuery<'w> { + i_16: Mut<'w, i16>, + i_32_opt: Option>, +} + +#[derive(FilterFetch)] +struct NumQueryFilter { + _u_16: With, + _u_32: With, + _or: Or<(With, Changed, Added)>, + _generic_tuple: (With, With

), + _without: Without>, + _tp: PhantomData<(T, P)>, +} + +fn spawn(mut commands: Commands) { + commands + .spawn() + .insert(1u16) + .insert(2u32) + .insert(3i16) + .insert(4i32) + .insert(5u64) + .insert(6i64); +} + +fn print_nums(mut query: Query, NumQueryFilter>) { + for num in query.iter_mut() { + println!("Print: {:#?}", num); + } +} + +fn print_nums_readonly(query: Query, NumQueryFilter>) { + for num in query.iter() { + println!("Print read-only: {:#?}", num); + } +}