diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index 51cd482b..b506b28b 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -124,6 +124,9 @@ macro_rules! setup_input_struct { } impl $zalsa::SalsaStructInDb for $Struct { + fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { + aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + } } impl $Struct { diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index bf9d98f5..e8b8af18 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -141,6 +141,9 @@ macro_rules! setup_interned_struct { } impl $zalsa::SalsaStructInDb for $Struct<'_> { + fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { + aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + } } unsafe impl $zalsa::Update for $Struct<'_> { diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 79ec96d5..ce5d313d 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -99,6 +99,9 @@ macro_rules! setup_tracked_fn { $zalsa::IngredientCache::new(); impl $zalsa::SalsaStructInDb for $InternedData<'_> { + fn lookup_ingredient_index(_aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { + None + } } impl $zalsa::interned::Configuration for $Configuration { @@ -199,7 +202,19 @@ macro_rules! setup_tracked_fn { aux: &dyn $zalsa::JarAux, first_index: $zalsa::IngredientIndex, ) -> Vec> { + let struct_index = $zalsa::macro_if! { + if $needs_interner { + first_index.successor(0) + } else { + <$InternedData as $zalsa::SalsaStructInDb>::lookup_ingredient_index(aux) + .expect( + "Salsa struct is passed as an argument of a tracked function, but its ingredient hasn't been added!" + ) + } + }; + let fn_ingredient = <$zalsa::function::IngredientImpl<$Configuration>>::new( + struct_index, first_index, aux, ); @@ -219,6 +234,10 @@ macro_rules! setup_tracked_fn { } } } + + fn salsa_struct_type_id(&self) -> Option { + None + } } #[allow(non_local_definitions)] diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index d0d42c6d..a783e376 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -152,6 +152,9 @@ macro_rules! setup_tracked_struct { } impl $zalsa::SalsaStructInDb for $Struct<'_> { + fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { + aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + } } impl $zalsa::TrackedStructInDb for $Struct<'_> { diff --git a/src/accumulator.rs b/src/accumulator.rs index b9355419..cd566f40 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -53,6 +53,10 @@ impl Jar for JarImpl { ) -> Vec> { vec![Box::new(>::new(first_index))] } + + fn salsa_struct_type_id(&self) -> Option { + None + } } pub struct IngredientImpl { diff --git a/src/function.rs b/src/function.rs index 07f13d49..b06be486 100644 --- a/src/function.rs +++ b/src/function.rs @@ -126,10 +126,10 @@ impl IngredientImpl where C: Configuration, { - pub fn new(index: IngredientIndex, aux: &dyn JarAux) -> Self { + pub fn new(struct_index: IngredientIndex, index: IngredientIndex, aux: &dyn JarAux) -> Self { Self { index, - memo_ingredient_index: aux.next_memo_ingredient_index(index), + memo_ingredient_index: aux.next_memo_ingredient_index(struct_index, index), lru: Default::default(), deleted_entries: Default::default(), } diff --git a/src/ingredient.rs b/src/ingredient.rs index 383fdc6b..8a46205d 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -23,10 +23,33 @@ pub trait Jar: Any { aux: &dyn JarAux, first_index: IngredientIndex, ) -> Vec>; + + /// If this jar's first ingredient is a salsa struct, return its `TypeId` + fn salsa_struct_type_id(&self) -> Option; } +/// Methods on the Salsa database available to jars while they are creating their ingredients. pub trait JarAux { - fn next_memo_ingredient_index(&self, ingredient_index: IngredientIndex) -> MemoIngredientIndex; + /// Return index of first ingredient from `jar` (based on the dynamic type of `jar`). + /// Returns `None` if the jar has not yet been added. + /// Used by tracked functions to lookup the ingredient index for the salsa struct they take as argument. + fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option; + + /// Returns the memo ingredient index that should be used to attach data from the given tracked function + /// to the given salsa struct (which the fn accepts as argument). + /// + /// The memo ingredient indices for a given function must be distinct from the memo indices + /// of all other functions that take the same salsa struct. + /// + /// # Parameters + /// + /// * `struct_ingredient_index`, the index of the salsa struct the memo will be attached to + /// * `ingredient_index`, the index of the tracked function whose data is stored in the memo + fn next_memo_ingredient_index( + &self, + struct_ingredient_index: IngredientIndex, + ingredient_index: IngredientIndex, + ) -> MemoIngredientIndex; } pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { diff --git a/src/input.rs b/src/input.rs index fdad27ac..4b82ef2f 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,4 +1,8 @@ -use std::{any::Any, fmt, ops::DerefMut}; +use std::{ + any::{Any, TypeId}, + fmt, + ops::DerefMut, +}; pub mod input_field; pub mod setter; @@ -60,6 +64,10 @@ impl Jar for JarImpl { })) .collect() } + + fn salsa_struct_type_id(&self) -> Option { + Some(TypeId::of::<::Struct>()) + } } pub struct IngredientImpl { diff --git a/src/interned.rs b/src/interned.rs index 0c6d32cd..f6767fb1 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -9,6 +9,7 @@ use crate::table::Slot; use crate::zalsa::IngredientIndex; use crate::zalsa_local::QueryOrigin; use crate::{Database, DatabaseKeyIndex, Id}; +use std::any::TypeId; use std::fmt; use std::hash::{BuildHasher, Hash, Hasher}; use std::marker::PhantomData; @@ -92,6 +93,10 @@ impl Jar for JarImpl { ) -> Vec> { vec![Box::new(IngredientImpl::::new(first_index)) as _] } + + fn salsa_struct_type_id(&self) -> Option { + Some(TypeId::of::<::Struct<'static>>()) + } } impl IngredientImpl diff --git a/src/salsa_struct.rs b/src/salsa_struct.rs index fcf7920a..8674dc12 100644 --- a/src/salsa_struct.rs +++ b/src/salsa_struct.rs @@ -1 +1,5 @@ -pub trait SalsaStructInDb {} +use crate::{plumbing::JarAux, IngredientIndex}; + +pub trait SalsaStructInDb { + fn lookup_ingredient_index(aux: &dyn JarAux) -> Option; +} diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 540bb765..8cca21e9 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -1,4 +1,4 @@ -use std::{fmt, hash::Hash, marker::PhantomData, ops::DerefMut}; +use std::{any::TypeId, fmt, hash::Hash, marker::PhantomData, ops::DerefMut}; use crossbeam::{atomic::AtomicCell, queue::SegQueue}; use tracked_field::FieldIngredientImpl; @@ -112,6 +112,10 @@ impl Jar for JarImpl { })) .collect() } + + fn salsa_struct_type_id(&self) -> Option { + Some(TypeId::of::<::Struct<'static>>()) + } } pub trait TrackedStructInDb: SalsaStructInDb { @@ -501,7 +505,8 @@ where // and the code that references the memo-table has a read-lock. let memo_table = unsafe { (*data).take_memo_table() }; for (memo_ingredient_index, memo) in memo_table.into_memos() { - let ingredient_index = zalsa.ingredient_index_for_memo(memo_ingredient_index); + let ingredient_index = + zalsa.ingredient_index_for_memo(self.ingredient_index, memo_ingredient_index); let executor = DatabaseKeyIndex { ingredient_index, diff --git a/src/zalsa.rs b/src/zalsa.rs index 2f8fa95f..b21923d4 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -1,5 +1,5 @@ use append_only_vec::AppendOnlyVec; -use parking_lot::Mutex; +use parking_lot::{Mutex, RwLock}; use rustc_hash::FxHashMap; use std::any::{Any, TypeId}; use std::marker::PhantomData; @@ -119,8 +119,10 @@ pub struct Zalsa { nonce: Nonce, - /// Number of memo ingredient indices created by calls to [`next_memo_ingredient_index`](`Self::next_memo_ingredient_index`) - memo_ingredients: Mutex>, + /// Map from the [`IngredientIndex::as_usize`][] of a salsa struct to a list of + /// [ingredient-indices](`IngredientIndex`) for tracked functions that have this salsa struct + /// as input. + memo_ingredient_indices: RwLock>>, /// Map from the type-id of an `impl Jar` to the index of its first ingredient. /// This is using a `Mutex` (versus, say, a `FxDashMap`) @@ -152,7 +154,7 @@ impl Zalsa { ingredients_vec: AppendOnlyVec::new(), ingredients_requiring_reset: AppendOnlyVec::new(), runtime: Runtime::default(), - memo_ingredients: Default::default(), + memo_ingredient_indices: Default::default(), } } @@ -186,11 +188,20 @@ impl Zalsa { { let jar_type_id = jar.type_id(); let mut jar_map = self.jar_map.lock(); - *jar_map - .entry(jar_type_id) - .or_insert_with(|| { - let index = IngredientIndex::from(self.ingredients_vec.len()); - let ingredients = jar.create_ingredients(self, index); + let mut should_create = false; + // First record the index we will use into the map and then go and create the ingredients. + // Those ingredients may invoke methods on the `JarAux` trait that read from this map + // to lookup ingredient indices for already created jars. + // + // Note that we still hold the lock above so only one jar is being created at a time and hence + // ingredient indices cannot overlap. + let index = *jar_map.entry(jar_type_id).or_insert_with(|| { + should_create = true; + IngredientIndex::from(self.ingredients_vec.len()) + }); + if should_create { + let aux = JarAuxImpl(self, &jar_map); + let ingredients = jar.create_ingredients(&aux, index); for ingredient in ingredients { let expected_index = ingredient.ingredient_index(); @@ -198,9 +209,7 @@ impl Zalsa { self.ingredients_requiring_reset.push(expected_index); } - let actual_index = self - .ingredients_vec - .push(ingredient); + let actual_index = self.ingredients_vec.push(ingredient); assert_eq!( expected_index.as_usize(), actual_index, @@ -209,10 +218,10 @@ impl Zalsa { expected_index, actual_index, ); - } - index - }) + } + + index } } @@ -290,15 +299,34 @@ impl Zalsa { pub(crate) fn ingredient_index_for_memo( &self, + struct_ingredient_index: IngredientIndex, memo_ingredient_index: MemoIngredientIndex, ) -> IngredientIndex { - self.memo_ingredients.lock()[memo_ingredient_index.as_usize()] + self.memo_ingredient_indices.read()[struct_ingredient_index.as_usize()] + [memo_ingredient_index.as_usize()] } } -impl JarAux for Zalsa { - fn next_memo_ingredient_index(&self, ingredient_index: IngredientIndex) -> MemoIngredientIndex { - let mut memo_ingredients = self.memo_ingredients.lock(); +struct JarAuxImpl<'a>(&'a Zalsa, &'a FxHashMap); + +impl JarAux for JarAuxImpl<'_> { + fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option { + self.1.get(&jar.type_id()).map(ToOwned::to_owned) + } + + fn next_memo_ingredient_index( + &self, + struct_ingredient_index: IngredientIndex, + ingredient_index: IngredientIndex, + ) -> MemoIngredientIndex { + let mut memo_ingredients = self.0.memo_ingredient_indices.write(); + let idx = struct_ingredient_index.as_usize(); + let memo_ingredients = if let Some(memo_ingredients) = memo_ingredients.get_mut(idx) { + memo_ingredients + } else { + memo_ingredients.resize_with(idx + 1, Vec::new); + memo_ingredients.get_mut(idx).unwrap() + }; let mi = MemoIngredientIndex(u32::try_from(memo_ingredients.len()).unwrap()); memo_ingredients.push(ingredient_index); mi diff --git a/tests/tracked_fn_multiple_args.rs b/tests/tracked_fn_multiple_args.rs new file mode 100644 index 00000000..7c014356 --- /dev/null +++ b/tests/tracked_fn_multiple_args.rs @@ -0,0 +1,25 @@ +//! Test that a `tracked` fn on multiple salsa struct args +//! compiles and executes successfully. + +#[salsa::input] +struct MyInput { + field: u32, +} + +#[salsa::interned] +struct MyInterned<'db> { + field: u32, +} + +#[salsa::tracked] +fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInput, interned: MyInterned<'db>) -> u32 { + input.field(db) + interned.field(db) +} + +#[test] +fn execute() { + let db = salsa::DatabaseImpl::new(); + let input = MyInput::new(&db, 22); + let interned = MyInterned::new(&db, 33); + assert_eq!(tracked_fn(&db, input, interned), 55); +}