diff --git a/crates/turbo-tasks-memory/Cargo.toml b/crates/turbo-tasks-memory/Cargo.toml index ae1acec5a021e..d7a615371236e 100644 --- a/crates/turbo-tasks-memory/Cargo.toml +++ b/crates/turbo-tasks-memory/Cargo.toml @@ -17,6 +17,7 @@ anyhow = { workspace = true } auto-hash-map = { workspace = true } concurrent-queue = { workspace = true } dashmap = { workspace = true } +indexmap = { workspace = true } nohash-hasher = { workspace = true } num_cpus = "1.13.1" once_cell = { workspace = true } @@ -33,8 +34,10 @@ turbo-tasks-malloc = { workspace = true, default-features = false } [dev-dependencies] criterion = { workspace = true, features = ["async_tokio"] } -indexmap = { workspace = true } lazy_static = { workspace = true } +loom = "0.7.2" +rand = { workspace = true, features = ["small_rng"] } +rstest = { workspace = true } serde = { workspace = true } tokio = { workspace = true, features = ["full"] } turbo-tasks-testing = { workspace = true } diff --git a/crates/turbo-tasks-memory/src/aggregation/aggregation_data.rs b/crates/turbo-tasks-memory/src/aggregation/aggregation_data.rs new file mode 100644 index 0000000000000..c14d312b28db8 --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/aggregation_data.rs @@ -0,0 +1,81 @@ +use std::ops::{Deref, DerefMut}; + +use super::{ + increase_aggregation_number_internal, AggregationContext, AggregationNode, AggregationNodeGuard, +}; +use crate::aggregation::balance_queue::BalanceQueue; + +/// Gives an reference to the aggregated data for a given item. This will +/// convert the item to a fully aggregated node. +pub fn aggregation_data<'l, C>( + ctx: &'l C, + node_id: &C::NodeRef, +) -> AggregationDataGuard> +where + C: AggregationContext + 'l, +{ + let guard = ctx.node(node_id); + if guard.aggregation_number() == u32::MAX { + AggregationDataGuard { guard } + } else { + let mut balance_queue = BalanceQueue::new(); + increase_aggregation_number_internal( + ctx, + &mut balance_queue, + guard, + node_id, + u32::MAX, + u32::MAX, + ); + balance_queue.process(ctx); + let guard = ctx.node(node_id); + debug_assert!(guard.aggregation_number() == u32::MAX); + AggregationDataGuard { guard } + } +} + +/// Converted the given node to a fully aggregated node. To make the next call +/// to `aggregation_data` instant. +pub fn prepare_aggregation_data(ctx: &C, node_id: &C::NodeRef) { + let mut balance_queue = BalanceQueue::new(); + increase_aggregation_number_internal( + ctx, + &mut balance_queue, + ctx.node(node_id), + node_id, + u32::MAX, + u32::MAX, + ); + balance_queue.process(ctx); +} + +/// A reference to the aggregated data of a node. This holds a lock to the node. +pub struct AggregationDataGuard { + guard: G, +} + +impl AggregationDataGuard { + pub fn into_inner(self) -> G { + self.guard + } +} + +impl Deref for AggregationDataGuard { + type Target = G::Data; + + fn deref(&self) -> &Self::Target { + match &*self.guard { + AggregationNode::Leaf { .. } => unreachable!(), + AggregationNode::Aggegating(aggregating) => &aggregating.data, + } + } +} + +impl DerefMut for AggregationDataGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + match &mut *self.guard { + AggregationNode::Leaf { .. } => unreachable!(), + AggregationNode::Aggegating(aggregating) => &mut aggregating.data, + } + } +} diff --git a/crates/turbo-tasks-memory/src/aggregation/balance_edge.rs b/crates/turbo-tasks-memory/src/aggregation/balance_edge.rs new file mode 100644 index 0000000000000..fd7a937b72943 --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/balance_edge.rs @@ -0,0 +1,206 @@ +use std::cmp::Ordering; + +use super::{ + balance_queue::BalanceQueue, + followers::{ + add_follower_count, remove_follower_count, remove_positive_follower_count, + RemovePositveFollowerCountResult, + }, + in_progress::is_in_progress, + increase_aggregation_number_internal, + uppers::{ + add_upper_count, remove_positive_upper_count, remove_upper_count, + RemovePositiveUpperCountResult, + }, + AggregationContext, AggregationNode, +}; + +// Migrated followers to uppers or uppers to followers depending on the +// aggregation numbers of the nodes involved in the edge. Might increase targets +// aggregation number if they are equal. +pub(super) fn balance_edge( + ctx: &C, + balance_queue: &mut BalanceQueue, + upper_id: &C::NodeRef, + mut upper_aggregation_number: u32, + target_id: &C::NodeRef, + mut target_aggregation_number: u32, +) -> (u32, u32) { + // too many uppers on target + let mut extra_uppers = 0; + // too many followers on upper + let mut extra_followers = 0; + // The last info about uppers + let mut uppers_count: Option = None; + // The last info about followers + let mut followers_count = None; + + loop { + let root = upper_aggregation_number == u32::MAX || target_aggregation_number == u32::MAX; + let order = if root { + Ordering::Greater + } else { + upper_aggregation_number.cmp(&target_aggregation_number) + }; + match order { + Ordering::Equal => { + // we probably want to increase the aggregation number of target + let upper = ctx.node(upper_id); + upper_aggregation_number = upper.aggregation_number(); + drop(upper); + if upper_aggregation_number != u32::MAX + && upper_aggregation_number == target_aggregation_number + { + let target = ctx.node(target_id); + target_aggregation_number = target.aggregation_number(); + if upper_aggregation_number == target_aggregation_number { + // increase target aggregation number + increase_aggregation_number_internal( + ctx, + balance_queue, + target, + target_id, + target_aggregation_number + 1, + target_aggregation_number + 1, + ); + } + } + } + Ordering::Less => { + // target should probably be a follower of upper + if uppers_count.map_or(false, |count| count <= 0) { + // We already removed all uppers, maybe too many + break; + } else if extra_followers == 0 { + let upper = ctx.node(upper_id); + upper_aggregation_number = upper.aggregation_number(); + if upper_aggregation_number < target_aggregation_number { + // target should be a follower of upper + // add some extra followers + let count = uppers_count.unwrap_or(1) as usize; + extra_followers += count; + followers_count = Some(add_follower_count( + ctx, + balance_queue, + upper, + upper_id, + target_id, + count, + true, + )); + } + } else { + // we already have extra followers, remove some uppers to balance + let count = extra_followers + extra_uppers; + let target = ctx.node(target_id); + if is_in_progress(ctx, upper_id) { + drop(target); + let mut upper = ctx.node(upper_id); + if is_in_progress(ctx, upper_id) { + let AggregationNode::Aggegating(aggregating) = &mut *upper else { + unreachable!(); + }; + aggregating.enqueued_balancing.push(( + upper_id.clone(), + upper_aggregation_number, + target_id.clone(), + target_aggregation_number, + )); + drop(upper); + // Somebody else will balance this edge + return (upper_aggregation_number, target_aggregation_number); + } + } else { + let RemovePositiveUpperCountResult { + removed_count, + remaining_count, + } = remove_positive_upper_count( + ctx, + balance_queue, + target, + upper_id, + count, + ); + decrease_numbers(removed_count, &mut extra_uppers, &mut extra_followers); + uppers_count = Some(remaining_count); + } + } + } + Ordering::Greater => { + // target should probably be an inner node of upper + if followers_count.map_or(false, |count| count <= 0) { + // We already removed all followers, maybe too many + break; + } else if extra_uppers == 0 { + let target = ctx.node(target_id); + target_aggregation_number = target.aggregation_number(); + if root || target_aggregation_number < upper_aggregation_number { + // target should be a inner node of upper + if is_in_progress(ctx, upper_id) { + drop(target); + let mut upper = ctx.node(upper_id); + if is_in_progress(ctx, upper_id) { + let AggregationNode::Aggegating(aggregating) = &mut *upper else { + unreachable!(); + }; + aggregating.enqueued_balancing.push(( + upper_id.clone(), + upper_aggregation_number, + target_id.clone(), + target_aggregation_number, + )); + drop(upper); + // Somebody else will balance this edge + return (upper_aggregation_number, target_aggregation_number); + } + } else { + // add some extra uppers + let count = followers_count.unwrap_or(1) as usize; + extra_uppers += count; + uppers_count = Some( + add_upper_count( + ctx, + balance_queue, + target, + target_id, + upper_id, + count, + true, + ) + .new_count, + ); + } + } + } else { + // we already have extra uppers, try to remove some followers to balance + let count = extra_followers + extra_uppers; + let upper = ctx.node(upper_id); + let RemovePositveFollowerCountResult { + removed_count, + remaining_count, + } = remove_positive_follower_count(ctx, balance_queue, upper, target_id, count); + decrease_numbers(removed_count, &mut extra_followers, &mut extra_uppers); + followers_count = Some(remaining_count); + } + } + } + } + if extra_followers > 0 { + let upper = ctx.node(upper_id); + remove_follower_count(ctx, balance_queue, upper, target_id, extra_followers); + } + if extra_uppers > 0 { + let target = ctx.node(target_id); + remove_upper_count(ctx, balance_queue, target, upper_id, extra_uppers); + } + (upper_aggregation_number, target_aggregation_number) +} + +fn decrease_numbers(amount: usize, a: &mut usize, b: &mut usize) { + if *a >= amount { + *a -= amount; + } else { + *b -= amount - *a; + *a = 0; + } +} diff --git a/crates/turbo-tasks-memory/src/aggregation/balance_queue.rs b/crates/turbo-tasks-memory/src/aggregation/balance_queue.rs new file mode 100644 index 0000000000000..1f11d4dd9a98d --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/balance_queue.rs @@ -0,0 +1,90 @@ +use std::{cmp::max, collections::HashMap, hash::Hash, mem::take}; + +use indexmap::IndexSet; + +use super::{balance_edge, AggregationContext}; + +/// Enqueued edges that need to be balanced. Deduplicates edges and keeps track +/// of aggregation numbers read during balancing. +pub struct BalanceQueue { + queue: IndexSet<(I, I)>, + aggregation_numbers: HashMap, +} + +impl BalanceQueue { + pub fn new() -> Self { + Self { + queue: IndexSet::default(), + aggregation_numbers: HashMap::default(), + } + } + + fn add_number(&mut self, id: I, number: u32) { + self.aggregation_numbers + .entry(id) + .and_modify(|n| *n = max(*n, number)) + .or_insert(number); + } + + /// Add an edge to the queue. The edge will be balanced during the next + /// call. + pub fn balance( + &mut self, + upper_id: I, + upper_aggregation_number: u32, + target_id: I, + target_aggregation_number: u32, + ) { + debug_assert!(upper_id != target_id); + self.add_number(upper_id.clone(), upper_aggregation_number); + self.add_number(target_id.clone(), target_aggregation_number); + self.queue.insert((upper_id.clone(), target_id.clone())); + } + + /// Add multiple edges to the queue. The edges will be balanced during the + /// next call. + pub fn balance_all(&mut self, edges: Vec<(I, u32, I, u32)>) { + for (upper_id, upper_aggregation_number, target_id, target_aggregation_number) in edges { + self.balance( + upper_id, + upper_aggregation_number, + target_id, + target_aggregation_number, + ); + } + } + + /// Process the queue and balance all enqueued edges. + pub fn process>(mut self, ctx: &C) { + while !self.queue.is_empty() { + let queue = take(&mut self.queue); + for (upper_id, target_id) in queue { + let upper_aggregation_number = self + .aggregation_numbers + .get(&upper_id) + .copied() + .unwrap_or_default(); + let target_aggregation_number = self + .aggregation_numbers + .get(&target_id) + .copied() + .unwrap_or_default(); + + let (u, t) = balance_edge( + ctx, + &mut self, + &upper_id, + upper_aggregation_number, + &target_id, + target_aggregation_number, + ); + if u != upper_aggregation_number { + self.add_number(upper_id, u); + } + if t != target_aggregation_number { + self.add_number(target_id, t); + } + } + } + } +} diff --git a/crates/turbo-tasks-memory/src/aggregation/change.rs b/crates/turbo-tasks-memory/src/aggregation/change.rs new file mode 100644 index 0000000000000..a0ff1eb605692 --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/change.rs @@ -0,0 +1,108 @@ +use std::hash::Hash; + +use super::{AggegatingNode, AggregationContext, AggregationNode, PreparedOperation, StackVec}; + +impl AggregationNode { + /// Prepares to apply a change to a node. Changes will be propagated to all + /// upper nodes. + #[must_use] + pub fn apply_change>( + &mut self, + ctx: &C, + change: C::DataChange, + ) -> Option> { + match self { + AggregationNode::Leaf { uppers, .. } => (!uppers.is_empty()).then(|| PreparedChange { + uppers: uppers.iter().cloned().collect::>(), + change, + }), + AggregationNode::Aggegating(aggegating) => { + let AggegatingNode { data, uppers, .. } = &mut **aggegating; + let change = ctx.apply_change(data, &change); + if uppers.is_empty() { + None + } else { + change.map(|change| PreparedChange { + uppers: uppers.iter().cloned().collect::>(), + change, + }) + } + } + } + } + + /// Prepares to apply a change to a node. Changes will be propagated to all + /// upper nodes. + #[must_use] + pub fn apply_change_ref<'l, C: AggregationContext>( + &mut self, + ctx: &C, + change: &'l C::DataChange, + ) -> Option> { + match self { + AggregationNode::Leaf { uppers, .. } => { + (!uppers.is_empty()).then(|| PreparedChangeRef::Borrowed { + uppers: uppers.iter().cloned().collect::>(), + change, + }) + } + AggregationNode::Aggegating(aggegating) => { + let AggegatingNode { data, uppers, .. } = &mut **aggegating; + let change = ctx.apply_change(data, change); + if uppers.is_empty() { + None + } else { + change.map(|change| PreparedChangeRef::Owned { + uppers: uppers.iter().cloned().collect::>(), + change, + }) + } + } + } + } +} + +/// A prepared `apply_change` operation. +pub struct PreparedChange { + uppers: StackVec, + change: C::DataChange, +} + +impl PreparedOperation for PreparedChange { + type Result = (); + fn apply(self, ctx: &C) { + let prepared = self + .uppers + .into_iter() + .filter_map(|upper_id| ctx.node(&upper_id).apply_change_ref(ctx, &self.change)) + .collect::>(); + prepared.apply(ctx); + } +} + +/// A prepared `apply_change_ref` operation. +pub enum PreparedChangeRef<'l, C: AggregationContext> { + Borrowed { + uppers: StackVec, + change: &'l C::DataChange, + }, + Owned { + uppers: StackVec, + change: C::DataChange, + }, +} + +impl<'l, C: AggregationContext> PreparedOperation for PreparedChangeRef<'l, C> { + type Result = (); + fn apply(self, ctx: &C) { + let (uppers, change) = match self { + PreparedChangeRef::Borrowed { uppers, change } => (uppers, change), + PreparedChangeRef::Owned { uppers, ref change } => (uppers, change), + }; + let prepared = uppers + .into_iter() + .filter_map(|upper_id| ctx.node(&upper_id).apply_change_ref(ctx, change)) + .collect::>(); + prepared.apply(ctx); + } +} diff --git a/crates/turbo-tasks-memory/src/aggregation/followers.rs b/crates/turbo-tasks-memory/src/aggregation/followers.rs new file mode 100644 index 0000000000000..f9f2f410ee125 --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/followers.rs @@ -0,0 +1,201 @@ +use super::{ + balance_queue::BalanceQueue, + in_progress::start_in_progress_all, + notify_lost_follower, notify_new_follower, + optimize::{optimize_aggregation_number_for_followers, MAX_FOLLOWERS}, + AggregationContext, AggregationNode, StackVec, +}; +use crate::count_hash_set::RemovePositiveCountResult; + +/// Add a follower to a node. Followers will be propagated to the uppers of the +/// node. +pub fn add_follower( + ctx: &C, + balance_queue: &mut BalanceQueue, + mut node: C::Guard<'_>, + node_id: &C::NodeRef, + follower_id: &C::NodeRef, + already_optimizing_for_node: bool, +) -> usize { + let AggregationNode::Aggegating(aggregating) = &mut *node else { + unreachable!(); + }; + if aggregating.followers.add_clonable(follower_id) { + on_added( + ctx, + balance_queue, + node, + node_id, + follower_id, + already_optimizing_for_node, + ) + } else { + 0 + } +} + +/// Handle the addition of a follower to a node. This function is called after +/// the follower has been added to the node. +pub fn on_added( + ctx: &C, + balance_queue: &mut BalanceQueue, + mut node: C::Guard<'_>, + node_id: &C::NodeRef, + follower_id: &C::NodeRef, + already_optimizing_for_node: bool, +) -> usize { + let AggregationNode::Aggegating(aggregating) = &mut *node else { + unreachable!(); + }; + let followers_len = aggregating.followers.len(); + let optimize = (!already_optimizing_for_node + && followers_len > MAX_FOLLOWERS + && (followers_len - MAX_FOLLOWERS).count_ones() == 1) + .then(|| { + aggregating + .followers + .iter() + .cloned() + .collect::>() + }); + let uppers = aggregating.uppers.iter().cloned().collect::>(); + start_in_progress_all(ctx, &uppers); + drop(node); + + let mut optimizing = false; + + if let Some(followers) = optimize { + optimizing = optimize_aggregation_number_for_followers( + ctx, + balance_queue, + node_id, + followers, + false, + ); + } + + let mut affected_nodes = uppers.len(); + for upper_id in uppers { + affected_nodes += notify_new_follower( + ctx, + balance_queue, + ctx.node(&upper_id), + &upper_id, + follower_id, + optimizing, + ); + } + affected_nodes +} + +/// Add a follower to a node with a count. Followers will be propagated to the +/// uppers of the node. +pub fn add_follower_count( + ctx: &C, + balance_queue: &mut BalanceQueue, + mut node: C::Guard<'_>, + node_id: &C::NodeRef, + follower_id: &C::NodeRef, + follower_count: usize, + already_optimizing_for_node: bool, +) -> isize { + let AggregationNode::Aggegating(aggregating) = &mut *node else { + unreachable!(); + }; + if aggregating + .followers + .add_clonable_count(follower_id, follower_count) + { + let count = aggregating.followers.get_count(follower_id); + on_added( + ctx, + balance_queue, + node, + node_id, + follower_id, + already_optimizing_for_node, + ); + count + } else { + aggregating.followers.get_count(follower_id) + } +} + +/// Remove a follower from a node. Followers will be propagated to the uppers of +/// the node. +pub fn remove_follower_count( + ctx: &C, + balance_queue: &mut BalanceQueue, + mut node: C::Guard<'_>, + follower_id: &C::NodeRef, + follower_count: usize, +) { + let AggregationNode::Aggegating(aggregating) = &mut *node else { + unreachable!(); + }; + if aggregating + .followers + .remove_clonable_count(follower_id, follower_count) + { + let uppers = aggregating.uppers.iter().cloned().collect::>(); + start_in_progress_all(ctx, &uppers); + drop(node); + for upper_id in uppers { + notify_lost_follower( + ctx, + balance_queue, + ctx.node(&upper_id), + &upper_id, + follower_id, + ); + } + } +} + +pub struct RemovePositveFollowerCountResult { + /// The amount of followers that have been removed. + pub removed_count: usize, + /// The amount of followers that are remaining. Might be negative. + pub remaining_count: isize, +} + +/// Remove a positive count of a follower from a node. Negative counts will not +/// be increased. The function returns how much of the count has been removed +/// and whats remaining. Followers will be propagated to the uppers of the node. +pub fn remove_positive_follower_count( + ctx: &C, + balance_queue: &mut BalanceQueue, + mut node: C::Guard<'_>, + follower_id: &C::NodeRef, + follower_count: usize, +) -> RemovePositveFollowerCountResult { + let AggregationNode::Aggegating(aggregating) = &mut *node else { + unreachable!(); + }; + let RemovePositiveCountResult { + removed, + removed_count, + count, + } = aggregating + .followers + .remove_positive_clonable_count(follower_id, follower_count); + + if removed { + let uppers = aggregating.uppers.iter().cloned().collect::>(); + start_in_progress_all(ctx, &uppers); + drop(node); + for upper_id in uppers { + notify_lost_follower( + ctx, + balance_queue, + ctx.node(&upper_id), + &upper_id, + follower_id, + ); + } + } + RemovePositveFollowerCountResult { + removed_count, + remaining_count: count, + } +} diff --git a/crates/turbo-tasks-memory/src/aggregation/in_progress.rs b/crates/turbo-tasks-memory/src/aggregation/in_progress.rs new file mode 100644 index 0000000000000..1dbb080630c3e --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/in_progress.rs @@ -0,0 +1,75 @@ +use std::{hash::Hash, mem::take}; + +use super::{balance_queue::BalanceQueue, AggregationContext, AggregationNode, StackVec}; + +impl AggregationNode { + /// Finishes an in progress operation. This might enqueue balancing + /// operations when they weren't possible due to the in progress operation. + pub(super) fn finish_in_progress>( + &mut self, + ctx: &C, + balance_queue: &mut BalanceQueue, + node_id: &I, + ) { + let value = ctx + .atomic_in_progress_counter(node_id) + .fetch_sub(1, std::sync::atomic::Ordering::AcqRel); + debug_assert!(value > 0); + if value == 1 { + if let AggregationNode::Aggegating(aggegating) = &mut *self { + balance_queue.balance_all(take(&mut aggegating.enqueued_balancing)) + } + } + } +} + +/// Finishes an in progress operation. This might enqueue balancing +/// operations when they weren't possible due to the in progress operation. +/// This version doesn't require a node guard. +pub fn finish_in_progress_without_node( + ctx: &C, + balance_queue: &mut BalanceQueue, + node_id: &C::NodeRef, +) { + let value = ctx + .atomic_in_progress_counter(node_id) + .fetch_sub(1, std::sync::atomic::Ordering::AcqRel); + debug_assert!(value > 0); + if value == 1 { + let mut node = ctx.node(node_id); + if let AggregationNode::Aggegating(aggegating) = &mut *node { + balance_queue.balance_all(take(&mut aggegating.enqueued_balancing)) + } + } +} + +/// Starts an in progress operation for all nodes in the list. +pub fn start_in_progress_all(ctx: &C, node_ids: &StackVec) { + for node_id in node_ids { + start_in_progress(ctx, node_id); + } +} + +/// Starts an in progress operation for a node. +pub fn start_in_progress(ctx: &C, node_id: &C::NodeRef) { + start_in_progress_count(ctx, node_id, 1); +} + +/// Starts multiple in progress operations for a node. +pub fn start_in_progress_count(ctx: &C, node_id: &C::NodeRef, count: u32) { + if count == 0 { + return; + } + ctx.atomic_in_progress_counter(node_id) + .fetch_add(count, std::sync::atomic::Ordering::Release); +} + +/// Checks if there is an in progress operation for a node. +/// It doesn't require a lock, but should run under a lock of the node or a +/// follower/inner node. +pub fn is_in_progress(ctx: &C, node_id: &C::NodeRef) -> bool { + let counter = ctx + .atomic_in_progress_counter(node_id) + .load(std::sync::atomic::Ordering::Acquire); + counter > 0 +} diff --git a/crates/turbo-tasks-memory/src/aggregation/increase.rs b/crates/turbo-tasks-memory/src/aggregation/increase.rs new file mode 100644 index 0000000000000..879cc37cd7e02 --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/increase.rs @@ -0,0 +1,311 @@ +use std::{hash::Hash, mem::take}; + +use super::{ + balance_queue::BalanceQueue, AggegatingNode, AggregationContext, AggregationNode, + AggregationNodeGuard, PreparedInternalOperation, PreparedOperation, StackVec, +}; +pub(super) const LEAF_NUMBER: u32 = 64; + +impl AggregationNode { + /// Increase the aggregation number of a node. This might temporarily + /// violate the graph invariants between uppers and followers of that node. + /// Therefore a balancing operation is enqueued to restore the invariants. + /// The actual change to the aggregation number is applied in the prepared + /// operation after checking all upper nodes aggregation numbers. + #[must_use] + pub(super) fn increase_aggregation_number_internal< + C: AggregationContext, + >( + &mut self, + _ctx: &C, + node_id: &C::NodeRef, + min_aggregation_number: u32, + target_aggregation_number: u32, + ) -> Option> { + if self.aggregation_number() >= min_aggregation_number { + return None; + } + Some(PreparedInternalIncreaseAggregationNumber::Lazy { + node_id: node_id.clone(), + uppers: self.uppers_mut().iter().cloned().collect(), + min_aggregation_number, + target_aggregation_number, + }) + } + + /// Increase the aggregation number of a node. This is only for testing + /// proposes. + #[cfg(test)] + pub fn increase_aggregation_number>( + &mut self, + _ctx: &C, + node_id: &C::NodeRef, + new_aggregation_number: u32, + ) -> Option> { + self.increase_aggregation_number_internal( + _ctx, + node_id, + new_aggregation_number, + new_aggregation_number, + ) + .map(PreparedIncreaseAggregationNumber) + } +} + +/// Increase the aggregation number of a node directly. This might temporarily +/// violate the graph invariants between uppers and followers of that node. +/// Therefore a balancing operation is enqueued to restore the invariants. +/// The actual change to the aggregation number is applied directly without +/// checking the upper nodes. +#[must_use] +pub(super) fn increase_aggregation_number_immediately( + _ctx: &C, + node: &mut C::Guard<'_>, + node_id: C::NodeRef, + min_aggregation_number: u32, + target_aggregation_number: u32, +) -> Option> { + if node.aggregation_number() >= min_aggregation_number { + return None; + } + let children = matches!(**node, AggregationNode::Leaf { .. }) + .then(|| node.children().collect::>()); + match &mut **node { + AggregationNode::Leaf { + aggregation_number, + uppers, + } => { + let children = children.unwrap(); + if target_aggregation_number < LEAF_NUMBER { + *aggregation_number = target_aggregation_number as u8; + Some(PreparedInternalIncreaseAggregationNumber::Leaf { + target_aggregation_number, + children, + }) + } else { + let uppers_copy = uppers.iter().cloned().collect::>(); + // Convert to Aggregating + **node = AggregationNode::Aggegating(Box::new(AggegatingNode { + aggregation_number: target_aggregation_number, + uppers: take(uppers), + followers: children.iter().cloned().collect(), + data: node.get_initial_data(), + enqueued_balancing: Vec::new(), + })); + let followers = children; + Some(PreparedInternalIncreaseAggregationNumber::Aggregating { + node_id, + uppers: uppers_copy, + followers, + target_aggregation_number, + }) + } + } + AggregationNode::Aggegating(aggegating) => { + let AggegatingNode { + followers, + uppers, + aggregation_number, + .. + } = &mut **aggegating; + let uppers = uppers.iter().cloned().collect::>(); + let followers = followers.iter().cloned().collect(); + *aggregation_number = target_aggregation_number; + Some(PreparedInternalIncreaseAggregationNumber::Aggregating { + node_id, + uppers, + followers, + target_aggregation_number, + }) + } + } +} + +/// A prepared `increase_aggregation_number` operation. +pub enum PreparedInternalIncreaseAggregationNumber { + Lazy { + node_id: C::NodeRef, + uppers: StackVec, + min_aggregation_number: u32, + target_aggregation_number: u32, + }, + Leaf { + children: StackVec, + target_aggregation_number: u32, + }, + Aggregating { + node_id: C::NodeRef, + uppers: StackVec, + followers: StackVec, + target_aggregation_number: u32, + }, +} + +impl PreparedInternalOperation + for PreparedInternalIncreaseAggregationNumber +{ + type Result = (); + fn apply(self, ctx: &C, balance_queue: &mut BalanceQueue) { + match self { + PreparedInternalIncreaseAggregationNumber::Lazy { + min_aggregation_number, + mut target_aggregation_number, + node_id, + uppers, + } => { + let mut need_to_run = true; + while need_to_run { + need_to_run = false; + let mut max = 0; + for upper_id in &uppers { + let upper = ctx.node(upper_id); + let aggregation_number = upper.aggregation_number(); + if aggregation_number != u32::MAX { + if aggregation_number > max { + max = aggregation_number; + } + if aggregation_number == target_aggregation_number { + target_aggregation_number += 1; + if max >= target_aggregation_number { + need_to_run = true; + } + } + } + } + } + drop(uppers); + let mut node = ctx.node(&node_id); + if node.aggregation_number() >= min_aggregation_number { + return; + } + let children = matches!(*node, AggregationNode::Leaf { .. }) + .then(|| node.children().collect::>()); + let (uppers, followers) = match &mut *node { + AggregationNode::Leaf { + aggregation_number, + uppers, + } => { + let children = children.unwrap(); + if target_aggregation_number < LEAF_NUMBER { + *aggregation_number = target_aggregation_number as u8; + drop(node); + for child_id in children { + increase_aggregation_number_internal( + ctx, + balance_queue, + ctx.node(&child_id), + &child_id, + target_aggregation_number + 1, + target_aggregation_number + 1, + ); + } + return; + } else { + let uppers_copy = uppers.iter().cloned().collect::>(); + // Convert to Aggregating + *node = AggregationNode::Aggegating(Box::new(AggegatingNode { + aggregation_number: target_aggregation_number, + uppers: take(uppers), + followers: children.iter().cloned().collect(), + data: node.get_initial_data(), + enqueued_balancing: Vec::new(), + })); + let followers = children; + drop(node); + (uppers_copy, followers) + } + } + AggregationNode::Aggegating(aggegating) => { + let AggegatingNode { + followers, + uppers, + aggregation_number, + .. + } = &mut **aggegating; + let uppers = uppers.iter().cloned().collect::>(); + let followers = followers.iter().cloned().collect(); + *aggregation_number = target_aggregation_number; + drop(node); + (uppers, followers) + } + }; + for follower_id in followers { + balance_queue.balance( + node_id.clone(), + target_aggregation_number, + follower_id, + 0, + ); + } + for upper_id in uppers { + balance_queue.balance(upper_id, 0, node_id.clone(), target_aggregation_number); + } + } + PreparedInternalIncreaseAggregationNumber::Leaf { + children, + target_aggregation_number, + } => { + for child_id in children { + increase_aggregation_number_internal( + ctx, + balance_queue, + ctx.node(&child_id), + &child_id, + target_aggregation_number + 1, + target_aggregation_number + 1, + ); + } + } + PreparedInternalIncreaseAggregationNumber::Aggregating { + node_id, + uppers, + followers, + target_aggregation_number, + } => { + for follower_id in followers { + balance_queue.balance( + node_id.clone(), + target_aggregation_number, + follower_id, + 0, + ); + } + for upper_id in uppers { + balance_queue.balance(upper_id, 0, node_id.clone(), target_aggregation_number); + } + } + } + } +} + +pub fn increase_aggregation_number_internal( + ctx: &C, + balance_queue: &mut BalanceQueue, + mut node: C::Guard<'_>, + node_id: &C::NodeRef, + min_aggregation_number: u32, + target_aggregation_number: u32, +) { + let prepared = node.increase_aggregation_number_internal( + ctx, + node_id, + min_aggregation_number, + target_aggregation_number, + ); + drop(node); + prepared.apply(ctx, balance_queue); +} + +/// A prepared `increase_aggregation_number` operation. +pub struct PreparedIncreaseAggregationNumber( + PreparedInternalIncreaseAggregationNumber, +); + +impl PreparedOperation for PreparedIncreaseAggregationNumber { + type Result = (); + fn apply(self, ctx: &C) { + let mut balance_queue = BalanceQueue::new(); + self.0.apply(ctx, &mut balance_queue); + balance_queue.process(ctx); + } +} diff --git a/crates/turbo-tasks-memory/src/aggregation/loom_tests.rs b/crates/turbo-tasks-memory/src/aggregation/loom_tests.rs new file mode 100644 index 0000000000000..3958f1e966414 --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/loom_tests.rs @@ -0,0 +1,268 @@ +use std::{ + fmt::Debug, + hash::Hash, + ops::{Deref, DerefMut}, + sync::{atomic::AtomicU32, Arc}, +}; + +use loom::{ + sync::{Mutex, MutexGuard}, + thread, +}; +use nohash_hasher::IsEnabled; +use rand::{rngs::SmallRng, Rng, SeedableRng}; +use ref_cast::RefCast; +use rstest::*; + +use super::{ + aggregation_data, handle_new_edge, AggregationContext, AggregationNode, AggregationNodeGuard, + PreparedOperation, +}; + +struct Node { + atomic: AtomicU32, + inner: Mutex, +} + +impl Node { + fn new(value: u32) -> Arc { + Arc::new(Node { + atomic: AtomicU32::new(0), + inner: Mutex::new(NodeInner { + children: Vec::new(), + aggregation_node: AggregationNode::new(), + value, + }), + }) + } + + fn add_child(self: &Arc, aggregation_context: &NodeAggregationContext, child: Arc) { + let mut guard = self.inner.lock().unwrap(); + guard.children.push(child.clone()); + let number_of_children = guard.children.len(); + let mut guard = unsafe { NodeGuard::new(guard, self.clone()) }; + let prepared = handle_new_edge( + aggregation_context, + &mut guard, + &NodeRef(self.clone()), + &NodeRef(child), + number_of_children, + ); + drop(guard); + prepared.apply(aggregation_context); + } +} + +#[derive(Copy, Clone)] +struct Change {} + +struct NodeInner { + children: Vec>, + aggregation_node: AggregationNode, + value: u32, +} + +struct NodeAggregationContext {} + +#[derive(Clone, RefCast)] +#[repr(transparent)] +struct NodeRef(Arc); + +impl Debug for NodeRef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "NodeRef({})", self.0.inner.lock().unwrap().value) + } +} + +impl Hash for NodeRef { + fn hash(&self, state: &mut H) { + Arc::as_ptr(&self.0).hash(state); + } +} + +impl IsEnabled for NodeRef {} + +impl PartialEq for NodeRef { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl Eq for NodeRef {} + +struct NodeGuard { + guard: MutexGuard<'static, NodeInner>, + // This field is important to keep the node alive + #[allow(dead_code)] + node: Arc, +} + +impl NodeGuard { + unsafe fn new(guard: MutexGuard<'_, NodeInner>, node: Arc) -> Self { + NodeGuard { + guard: unsafe { std::mem::transmute(guard) }, + node, + } + } +} + +impl Deref for NodeGuard { + type Target = AggregationNode; + + fn deref(&self) -> &Self::Target { + &self.guard.aggregation_node + } +} + +impl DerefMut for NodeGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.guard.aggregation_node + } +} + +impl AggregationNodeGuard for NodeGuard { + type Data = Aggregated; + type NodeRef = NodeRef; + type DataChange = Change; + type ChildrenIter<'a> = impl Iterator + 'a; + + fn children(&self) -> Self::ChildrenIter<'_> { + self.guard + .children + .iter() + .map(|child| NodeRef(child.clone())) + } + + fn get_remove_change(&self) -> Option { + None + } + + fn get_add_change(&self) -> Option { + None + } + + fn get_initial_data(&self) -> Self::Data { + Aggregated {} + } +} + +impl AggregationContext for NodeAggregationContext { + type Guard<'l> = NodeGuard where Self: 'l; + type Data = Aggregated; + type NodeRef = NodeRef; + type DataChange = Change; + + fn node<'b>(&'b self, reference: &Self::NodeRef) -> Self::Guard<'b> { + let r = reference.0.clone(); + let guard = reference.0.inner.lock().unwrap(); + unsafe { NodeGuard::new(guard, r) } + } + + fn atomic_in_progress_counter<'l>(&self, id: &'l NodeRef) -> &'l AtomicU32 + where + Self: 'l, + { + &id.0.atomic + } + + fn apply_change(&self, _data: &mut Aggregated, _change: &Change) -> Option { + None + } + + fn data_to_add_change(&self, _data: &Self::Data) -> Option { + None + } + + fn data_to_remove_change(&self, _data: &Self::Data) -> Option { + None + } +} + +#[derive(Default)] +struct Aggregated {} + +// #[test] +#[allow(dead_code)] +fn fuzzy_loom_new() { + for size in [10, 20] { + for _ in 0..1000 { + let seed = rand::random(); + println!("Seed {} Size {}", seed, size); + fuzzy_loom(seed, size); + } + } +} + +#[rstest] +#[case::a(3302552607, 10)] +// #[case::b(3629477471, 50)] +// #[case::c(1006976052, 20)] +// #[case::d(2174645157, 10)] +fn fuzzy_loom(#[case] seed: u32, #[case] count: u32) { + let mut builder = loom::model::Builder::new(); + builder.max_branches = 100000; + builder.check(move || { + loom::stop_exploring(); + thread::Builder::new() + .stack_size(80000) + .spawn(move || { + let ctx = NodeAggregationContext {}; + + let mut seed_buffer = [0; 32]; + seed_buffer[0..4].copy_from_slice(&seed.to_be_bytes()); + let mut r = SmallRng::from_seed(seed_buffer); + let mut nodes = Vec::new(); + for i in 0..count { + nodes.push(Node::new(i)); + } + aggregation_data(&ctx, &NodeRef(nodes[0].clone())); + aggregation_data(&ctx, &NodeRef(nodes[1].clone())); + + // setup graph + for _ in 0..20 { + let parent = r.gen_range(0..nodes.len() - 1); + let child = r.gen_range(parent + 1..nodes.len()); + let parent_node = nodes[parent].clone(); + let child_node = nodes[child].clone(); + parent_node.add_child(&ctx, child_node); + } + + let mut edges = Vec::new(); + for _ in 0..2 { + let parent = r.gen_range(0..nodes.len() - 1); + let child = r.gen_range(parent + 1..nodes.len()); + let parent_node = nodes[parent].clone(); + let child_node = nodes[child].clone(); + edges.push((parent_node, child_node)); + } + + let ctx = Arc::new(ctx); + + loom::explore(); + + let mut threads = Vec::new(); + + // Fancy testing + for (parent_node, child_node) in edges.iter() { + let parent_node = parent_node.clone(); + let child_node = child_node.clone(); + let ctx = ctx.clone(); + threads.push( + thread::Builder::new() + .stack_size(80000) + .spawn(move || { + parent_node.add_child(&ctx, child_node); + }) + .unwrap(), + ); + } + + for thread in threads { + thread.join().unwrap(); + } + }) + .unwrap() + .join() + .unwrap(); + }); +} diff --git a/crates/turbo-tasks-memory/src/aggregation/lost_edge.rs b/crates/turbo-tasks-memory/src/aggregation/lost_edge.rs new file mode 100644 index 0000000000000..7c95e3cbb461c --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/lost_edge.rs @@ -0,0 +1,115 @@ +use std::hash::Hash; + +use super::{ + balance_queue::BalanceQueue, in_progress::start_in_progress_count, + notify_lost_follower::PreparedNotifyLostFollower, AggregationContext, AggregationNode, + PreparedInternalOperation, PreparedOperation, StackVec, +}; + +impl AggregationNode { + /// Handles the loss of edges to a node. This will notify all upper nodes + /// about the new follower or add the new node as inner node. + #[must_use] + pub fn handle_lost_edges>( + &mut self, + ctx: &C, + origin_id: &C::NodeRef, + target_ids: impl IntoIterator, + ) -> Option> { + match self { + AggregationNode::Leaf { uppers, .. } => { + let uppers = uppers.iter().cloned().collect::>(); + let target_ids: StackVec<_> = target_ids.into_iter().collect(); + for upper_id in &uppers { + start_in_progress_count(ctx, upper_id, target_ids.len() as u32); + } + Some(PreparedLostEdgesInner::Leaf { uppers, target_ids }.into()) + } + AggregationNode::Aggegating(_) => { + let notify = target_ids + .into_iter() + .filter_map(|target_id| { + self.notify_lost_follower_not_in_progress(ctx, origin_id, &target_id) + }) + .collect::>(); + (!notify.is_empty()).then(|| notify.into()) + } + } + } +} + +/// A prepared `handle_lost_edges` operation. +pub struct PreparedLostEdges { + inner: PreparedLostEdgesInner, +} + +impl From> for PreparedLostEdges { + fn from(inner: PreparedLostEdgesInner) -> Self { + Self { inner } + } +} + +impl From>> for PreparedLostEdges { + fn from(notify: StackVec>) -> Self { + Self { + inner: PreparedLostEdgesInner::Aggregating { notify }, + } + } +} + +#[allow(clippy::large_enum_variant)] +enum PreparedLostEdgesInner { + Leaf { + uppers: StackVec, + target_ids: StackVec, + }, + Aggregating { + notify: StackVec>, + }, +} + +impl PreparedOperation for PreparedLostEdges { + type Result = (); + fn apply(self, ctx: &C) { + let mut balance_queue = BalanceQueue::new(); + match self.inner { + PreparedLostEdgesInner::Leaf { uppers, target_ids } => { + // TODO This could be more efficient + for upper_id in uppers { + let mut upper = ctx.node(&upper_id); + let prepared = target_ids + .iter() + .filter_map(|target_id| { + upper.notify_lost_follower( + ctx, + &mut balance_queue, + &upper_id, + target_id, + ) + }) + .collect::>(); + drop(upper); + prepared.apply(ctx, &mut balance_queue); + } + } + PreparedLostEdgesInner::Aggregating { notify } => { + notify.apply(ctx, &mut balance_queue); + } + } + balance_queue.process(ctx); + } +} + +/// Handles the loss of edges to a node. This will notify all upper nodes +/// about the new follower or add the new node as inner node. +#[cfg(test)] +pub fn handle_lost_edges( + ctx: &C, + mut origin: C::Guard<'_>, + origin_id: &C::NodeRef, + target_ids: impl IntoIterator, +) { + let p = origin.handle_lost_edges(ctx, origin_id, target_ids); + drop(origin); + p.apply(ctx); +} diff --git a/crates/turbo-tasks-memory/src/aggregation/mod.rs b/crates/turbo-tasks-memory/src/aggregation/mod.rs new file mode 100644 index 0000000000000..6658cedcb897e --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/mod.rs @@ -0,0 +1,235 @@ +use std::{fmt::Debug, hash::Hash, ops::DerefMut, sync::atomic::AtomicU32}; + +use smallvec::SmallVec; + +use crate::count_hash_set::CountHashSet; + +mod aggregation_data; +mod balance_edge; +mod balance_queue; +mod change; +mod followers; +mod in_progress; +mod increase; +#[cfg(test)] +mod loom_tests; +mod lost_edge; +mod new_edge; +mod notify_lost_follower; +mod notify_new_follower; +mod optimize; +mod root_query; +#[cfg(test)] +mod tests; +mod uppers; + +pub use aggregation_data::{aggregation_data, prepare_aggregation_data, AggregationDataGuard}; +use balance_edge::balance_edge; +use increase::increase_aggregation_number_internal; +pub use new_edge::handle_new_edge; +use notify_lost_follower::notify_lost_follower; +use notify_new_follower::notify_new_follower; +pub use root_query::{query_root_info, RootQuery}; + +use self::balance_queue::BalanceQueue; + +type StackVec = SmallVec<[I; 16]>; + +/// The aggregation node structure. This stores the aggregation number, the +/// aggregation edges to uppers and followers and the aggregated data. +pub enum AggregationNode { + Leaf { + aggregation_number: u8, + uppers: CountHashSet, + }, + Aggegating(Box>), +} + +impl AggregationNode { + pub fn new() -> Self { + Self::Leaf { + aggregation_number: 0, + uppers: CountHashSet::new(), + } + } +} + +/// The aggregation node structure for aggregating nodes. +pub struct AggegatingNode { + aggregation_number: u32, + uppers: CountHashSet, + followers: CountHashSet, + data: D, + enqueued_balancing: Vec<(I, u32, I, u32)>, +} + +impl AggregationNode { + /// Returns the aggregation number of the node. + pub fn aggregation_number(&self) -> u32 { + match self { + AggregationNode::Leaf { + aggregation_number, .. + } => *aggregation_number as u32, + AggregationNode::Aggegating(aggegating) => aggegating.aggregation_number, + } + } + + fn is_leaf(&self) -> bool { + matches!(self, AggregationNode::Leaf { .. }) + } + + fn uppers(&self) -> &CountHashSet { + match self { + AggregationNode::Leaf { uppers, .. } => uppers, + AggregationNode::Aggegating(aggegating) => &aggegating.uppers, + } + } + + fn uppers_mut(&mut self) -> &mut CountHashSet { + match self { + AggregationNode::Leaf { uppers, .. } => uppers, + AggregationNode::Aggegating(aggegating) => &mut aggegating.uppers, + } + } + + fn followers(&self) -> Option<&CountHashSet> { + match self { + AggregationNode::Leaf { .. } => None, + AggregationNode::Aggegating(aggegating) => Some(&aggegating.followers), + } + } +} + +/// A prepared operation. Must be applied outside of node locks. +#[must_use] +pub trait PreparedOperation { + type Result; + fn apply(self, ctx: &C) -> Self::Result; +} + +impl> PreparedOperation for Option { + type Result = Option; + fn apply(self, ctx: &C) -> Self::Result { + self.map(|prepared| prepared.apply(ctx)) + } +} + +impl> PreparedOperation for Vec { + type Result = (); + fn apply(self, ctx: &C) -> Self::Result { + for prepared in self { + prepared.apply(ctx); + } + } +} + +impl, const N: usize> PreparedOperation + for SmallVec<[T; N]> +{ + type Result = (); + fn apply(self, ctx: &C) -> Self::Result { + for prepared in self { + prepared.apply(ctx); + } + } +} + +/// A prepared internal operation. Must be applied inside of node locks and with +/// a balance queue. +#[must_use] +trait PreparedInternalOperation { + type Result; + fn apply(self, ctx: &C, balance_queue: &mut BalanceQueue) -> Self::Result; +} + +impl> PreparedInternalOperation + for Option +{ + type Result = Option; + fn apply(self, ctx: &C, balance_queue: &mut BalanceQueue) -> Self::Result { + self.map(|prepared| prepared.apply(ctx, balance_queue)) + } +} + +impl> PreparedInternalOperation + for Vec +{ + type Result = (); + fn apply(self, ctx: &C, balance_queue: &mut BalanceQueue) -> Self::Result { + for prepared in self { + prepared.apply(ctx, balance_queue); + } + } +} + +impl, const N: usize> + PreparedInternalOperation for SmallVec<[T; N]> +{ + type Result = (); + fn apply(self, ctx: &C, balance_queue: &mut BalanceQueue) -> Self::Result { + for prepared in self { + prepared.apply(ctx, balance_queue); + } + } +} + +/// Context for aggregation operations. +pub trait AggregationContext { + type NodeRef: Clone + Eq + Hash + Debug; + type Guard<'l>: AggregationNodeGuard< + NodeRef = Self::NodeRef, + Data = Self::Data, + DataChange = Self::DataChange, + > + where + Self: 'l; + type Data; + type DataChange; + + /// Gets mutable access to an item. + fn node<'l>(&'l self, id: &Self::NodeRef) -> Self::Guard<'l>; + + /// Get the atomic in progress counter for a node. + fn atomic_in_progress_counter<'l>(&self, id: &'l Self::NodeRef) -> &'l AtomicU32 + where + Self: 'l; + + /// Apply a changeset to an aggregated data object. Returns a new changeset + /// that should be applied to the next aggregation level. Might return None, + /// if no change should be applied to the next level. + fn apply_change( + &self, + data: &mut Self::Data, + change: &Self::DataChange, + ) -> Option; + + /// Creates a changeset from an aggregated data object, that represents + /// adding the aggregated node to an aggregated node of the next level. + fn data_to_add_change(&self, data: &Self::Data) -> Option; + /// Creates a changeset from an aggregated data object, that represents + /// removing the aggregated node from an aggregated node of the next level. + fn data_to_remove_change(&self, data: &Self::Data) -> Option; +} + +/// A guard for a node that allows to access the aggregation node, children and +/// data. +pub trait AggregationNodeGuard: + DerefMut> +{ + type NodeRef: Clone + Eq + Hash; + type Data; + type DataChange; + + type ChildrenIter<'a>: Iterator + 'a + where + Self: 'a; + + /// Returns an iterator over the children. + fn children(&self) -> Self::ChildrenIter<'_>; + /// Returns a changeset that represents the addition of the node. + fn get_add_change(&self) -> Option; + /// Returns a changeset that represents the removal of the node. + fn get_remove_change(&self) -> Option; + /// Returns the aggregated data which contains only that node + fn get_initial_data(&self) -> Self::Data; +} diff --git a/crates/turbo-tasks-memory/src/aggregation/new_edge.rs b/crates/turbo-tasks-memory/src/aggregation/new_edge.rs new file mode 100644 index 0000000000000..1c295434a2316 --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/new_edge.rs @@ -0,0 +1,179 @@ +use super::{ + balance_queue::BalanceQueue, + in_progress::start_in_progress_all, + increase::{ + increase_aggregation_number_immediately, PreparedInternalIncreaseAggregationNumber, + LEAF_NUMBER, + }, + increase_aggregation_number_internal, notify_new_follower, + notify_new_follower::PreparedNotifyNewFollower, + optimize::optimize_aggregation_number_for_uppers, + AggregationContext, AggregationNode, PreparedInternalOperation, PreparedOperation, StackVec, +}; + +const BUFFER_SPACE: u32 = 2; + +const MAX_UPPERS_TIMES_CHILDREN: usize = 32; + +const MAX_AFFECTED_NODES: usize = 4096; + +/// Handle the addition of a new edge to a node. The the edge is propagated to +/// the uppers of that node or added a inner node. +#[tracing::instrument(level = tracing::Level::TRACE, name = "handle_new_edge_preparation", skip_all)] +pub fn handle_new_edge<'l, C: AggregationContext>( + ctx: &C, + origin: &mut C::Guard<'l>, + origin_id: &C::NodeRef, + target_id: &C::NodeRef, + number_of_children: usize, +) -> impl PreparedOperation { + match **origin { + AggregationNode::Leaf { + ref mut aggregation_number, + ref uppers, + } => { + if number_of_children.count_ones() == 1 + && (uppers.len() + 1) * number_of_children >= MAX_UPPERS_TIMES_CHILDREN + { + let uppers = uppers.iter().cloned().collect::>(); + start_in_progress_all(ctx, &uppers); + let increase = increase_aggregation_number_immediately( + ctx, + origin, + origin_id.clone(), + LEAF_NUMBER, + LEAF_NUMBER, + ) + .unwrap(); + Some(PreparedNewEdge::Upgraded { + uppers, + target_id: target_id.clone(), + increase, + }) + } else { + let min_aggregation_number = *aggregation_number as u32 + 1; + let target_aggregation_number = *aggregation_number as u32 + 1 + BUFFER_SPACE; + let uppers = uppers.iter().cloned().collect::>(); + start_in_progress_all(ctx, &uppers); + Some(PreparedNewEdge::Leaf { + min_aggregation_number, + target_aggregation_number, + uppers, + target_id: target_id.clone(), + }) + } + } + AggregationNode::Aggegating(_) => origin + .notify_new_follower_not_in_progress(ctx, origin_id, target_id) + .map(|notify| PreparedNewEdge::Aggegating { + target_id: target_id.clone(), + notify, + }), + } +} + +/// A prepared `handle_new_edge` operation. +enum PreparedNewEdge { + Leaf { + min_aggregation_number: u32, + target_aggregation_number: u32, + uppers: StackVec, + target_id: C::NodeRef, + }, + Upgraded { + uppers: StackVec, + target_id: C::NodeRef, + increase: PreparedInternalIncreaseAggregationNumber, + }, + Aggegating { + notify: PreparedNotifyNewFollower, + target_id: C::NodeRef, + }, +} + +impl PreparedOperation for PreparedNewEdge { + type Result = (); + #[tracing::instrument(level = tracing::Level::TRACE, name = "handle_new_edge", skip_all)] + fn apply(self, ctx: &C) { + let mut balance_queue = BalanceQueue::new(); + match self { + PreparedNewEdge::Leaf { + min_aggregation_number, + target_aggregation_number, + uppers, + target_id, + } => { + let _span = tracing::trace_span!("leaf").entered(); + { + let _span = + tracing::trace_span!("increase_aggregation_number_internal").entered(); + // TODO add to prepared + increase_aggregation_number_internal( + ctx, + &mut balance_queue, + ctx.node(&target_id), + &target_id, + min_aggregation_number, + target_aggregation_number, + ); + } + let mut affected_nodes = 0; + for upper_id in uppers { + affected_nodes += notify_new_follower( + ctx, + &mut balance_queue, + ctx.node(&upper_id), + &upper_id, + &target_id, + false, + ); + if affected_nodes > MAX_AFFECTED_NODES { + handle_expensive_node(ctx, &mut balance_queue, &target_id); + } + } + } + PreparedNewEdge::Upgraded { + uppers, + target_id, + increase, + } => { + // Since it was added to a leaf node, we would add it to the uppers + for upper_id in uppers { + notify_new_follower( + ctx, + &mut balance_queue, + ctx.node(&upper_id), + &upper_id, + &target_id, + true, + ); + } + // The balancing will attach it to the aggregated node later + increase.apply(ctx, &mut balance_queue); + } + PreparedNewEdge::Aggegating { target_id, notify } => { + let affected_nodes = notify.apply(ctx, &mut balance_queue); + if affected_nodes > MAX_AFFECTED_NODES { + handle_expensive_node(ctx, &mut balance_queue, &target_id); + } + } + } + let _span = tracing::trace_span!("balance_queue").entered(); + balance_queue.process(ctx); + } +} + +/// Called in the case when we detect that adding this node was expensive. It +/// optimizes the aggregation number of the node so it can be cheaper on the +/// next call. +fn handle_expensive_node( + ctx: &C, + balance_queue: &mut BalanceQueue, + node_id: &C::NodeRef, +) { + let node = ctx.node(node_id); + let uppers = node.uppers().iter().cloned().collect::>(); + let leaf = matches!(*node, AggregationNode::Leaf { .. }); + drop(node); + optimize_aggregation_number_for_uppers(ctx, balance_queue, node_id, leaf, uppers); +} diff --git a/crates/turbo-tasks-memory/src/aggregation/notify_lost_follower.rs b/crates/turbo-tasks-memory/src/aggregation/notify_lost_follower.rs new file mode 100644 index 0000000000000..bbddd4862ba18 --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/notify_lost_follower.rs @@ -0,0 +1,217 @@ +use std::{hash::Hash, thread::yield_now}; + +use super::{ + balance_queue::BalanceQueue, + in_progress::{finish_in_progress_without_node, start_in_progress, start_in_progress_all}, + AggegatingNode, AggregationContext, AggregationNode, AggregationNodeGuard, + PreparedInternalOperation, PreparedOperation, StackVec, +}; +use crate::count_hash_set::RemoveIfEntryResult; + +impl AggregationNode { + /// Called when a inner node of the upper node has lost a follower + /// It's expected that the upper node is flagged as "in progress". + pub(super) fn notify_lost_follower>( + &mut self, + ctx: &C, + balance_queue: &mut BalanceQueue, + upper_id: &C::NodeRef, + follower_id: &C::NodeRef, + ) -> Option> { + let AggregationNode::Aggegating(aggregating) = self else { + unreachable!(); + }; + match aggregating.followers.remove_if_entry(follower_id) { + RemoveIfEntryResult::PartiallyRemoved => { + self.finish_in_progress(ctx, balance_queue, upper_id); + None + } + RemoveIfEntryResult::Removed => { + let uppers = aggregating.uppers.iter().cloned().collect::>(); + start_in_progress_all(ctx, &uppers); + self.finish_in_progress(ctx, balance_queue, upper_id); + Some(PreparedNotifyLostFollower::RemovedFollower { + uppers, + follower_id: follower_id.clone(), + }) + } + RemoveIfEntryResult::NotPresent => Some(PreparedNotifyLostFollower::NotFollower { + upper_id: upper_id.clone(), + follower_id: follower_id.clone(), + }), + } + } + + /// Called when a inner node of the upper node has lost a follower. + /// It's expected that the upper node is NOT flagged as "in progress". + pub(super) fn notify_lost_follower_not_in_progress< + C: AggregationContext, + >( + &mut self, + ctx: &C, + upper_id: &C::NodeRef, + follower_id: &C::NodeRef, + ) -> Option> { + let AggregationNode::Aggegating(aggregating) = self else { + unreachable!(); + }; + match aggregating.followers.remove_if_entry(follower_id) { + RemoveIfEntryResult::PartiallyRemoved => None, + RemoveIfEntryResult::Removed => { + let uppers = aggregating.uppers.iter().cloned().collect::>(); + start_in_progress_all(ctx, &uppers); + Some(PreparedNotifyLostFollower::RemovedFollower { + uppers, + follower_id: follower_id.clone(), + }) + } + RemoveIfEntryResult::NotPresent => { + start_in_progress(ctx, upper_id); + Some(PreparedNotifyLostFollower::NotFollower { + upper_id: upper_id.clone(), + follower_id: follower_id.clone(), + }) + } + } + } +} + +/// A prepared `notify_lost_follower` operation. +pub(super) enum PreparedNotifyLostFollower { + RemovedFollower { + uppers: StackVec, + follower_id: C::NodeRef, + }, + NotFollower { + upper_id: C::NodeRef, + follower_id: C::NodeRef, + }, +} + +impl PreparedInternalOperation for PreparedNotifyLostFollower { + type Result = (); + fn apply(self, ctx: &C, balance_queue: &mut BalanceQueue) { + match self { + PreparedNotifyLostFollower::RemovedFollower { + uppers, + follower_id, + } => { + for upper_id in uppers { + notify_lost_follower( + ctx, + balance_queue, + ctx.node(&upper_id), + &upper_id, + &follower_id, + ); + } + } + PreparedNotifyLostFollower::NotFollower { + upper_id, + follower_id, + } => { + loop { + let mut follower = ctx.node(&follower_id); + match follower.uppers_mut().remove_if_entry(&upper_id) { + RemoveIfEntryResult::PartiallyRemoved => { + finish_in_progress_without_node(ctx, balance_queue, &upper_id); + drop(follower); + return; + } + RemoveIfEntryResult::Removed => { + let remove_change = get_aggregated_remove_change(ctx, &follower); + let followers = match &*follower { + AggregationNode::Leaf { .. } => { + follower.children().collect::>() + } + AggregationNode::Aggegating(aggregating) => { + let AggegatingNode { ref followers, .. } = **aggregating; + followers.iter().cloned().collect::>() + } + }; + drop(follower); + + let mut upper = ctx.node(&upper_id); + let remove_change = remove_change + .map(|remove_change| upper.apply_change(ctx, remove_change)); + let prepared = followers + .into_iter() + .filter_map(|follower_id| { + upper.notify_lost_follower_not_in_progress( + ctx, + &upper_id, + &follower_id, + ) + }) + .collect::>(); + upper.finish_in_progress(ctx, balance_queue, &upper_id); + drop(upper); + prepared.apply(ctx, balance_queue); + remove_change.apply(ctx); + return; + } + RemoveIfEntryResult::NotPresent => { + drop(follower); + let mut upper = ctx.node(&upper_id); + let AggregationNode::Aggegating(aggregating) = &mut *upper else { + unreachable!(); + }; + match aggregating.followers.remove_if_entry(&follower_id) { + RemoveIfEntryResult::PartiallyRemoved => { + upper.finish_in_progress(ctx, balance_queue, &upper_id); + return; + } + RemoveIfEntryResult::Removed => { + let uppers = + aggregating.uppers.iter().cloned().collect::>(); + start_in_progress_all(ctx, &uppers); + upper.finish_in_progress(ctx, balance_queue, &upper_id); + drop(upper); + for upper_id in uppers { + notify_lost_follower( + ctx, + balance_queue, + ctx.node(&upper_id), + &upper_id, + &follower_id, + ); + } + return; + } + RemoveIfEntryResult::NotPresent => { + drop(upper); + yield_now() + // Retry, concurrency + } + } + } + } + } + } + } + } +} + +/// Notifies the upper node that a follower has been lost. +/// It's expected that the upper node is flagged as "in progress". +pub fn notify_lost_follower( + ctx: &C, + balance_queue: &mut BalanceQueue, + mut upper: C::Guard<'_>, + upper_id: &C::NodeRef, + follower_id: &C::NodeRef, +) { + let p = upper.notify_lost_follower(ctx, balance_queue, upper_id, follower_id); + drop(upper); + p.apply(ctx, balance_queue); +} + +fn get_aggregated_remove_change( + ctx: &C, + guard: &C::Guard<'_>, +) -> Option { + match &**guard { + AggregationNode::Leaf { .. } => guard.get_remove_change(), + AggregationNode::Aggegating(aggegating) => ctx.data_to_remove_change(&aggegating.data), + } +} diff --git a/crates/turbo-tasks-memory/src/aggregation/notify_new_follower.rs b/crates/turbo-tasks-memory/src/aggregation/notify_new_follower.rs new file mode 100644 index 0000000000000..0753e292c2bc8 --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/notify_new_follower.rs @@ -0,0 +1,243 @@ +use std::{cmp::Ordering, hash::Hash}; + +use super::{ + balance_queue::BalanceQueue, + followers::add_follower, + in_progress::{finish_in_progress_without_node, start_in_progress}, + increase_aggregation_number_internal, + optimize::optimize_aggregation_number_for_uppers, + uppers::add_upper, + AggregationContext, AggregationNode, PreparedInternalOperation, StackVec, +}; + +const MAX_AFFECTED_NODES: usize = 4096; + +impl AggregationNode { + // Called when a inner node of the upper node has a new follower. + // It's expected that the upper node is flagged as "in progress". + pub(super) fn notify_new_follower>( + &mut self, + ctx: &C, + balance_queue: &mut BalanceQueue, + upper_id: &C::NodeRef, + follower_id: &C::NodeRef, + already_optimizing_for_upper: bool, + ) -> Option> { + let AggregationNode::Aggegating(aggregating) = self else { + unreachable!(); + }; + if aggregating.followers.add_if_entry(follower_id) { + self.finish_in_progress(ctx, balance_queue, upper_id); + None + } else { + let upper_aggregation_number = aggregating.aggregation_number; + if upper_aggregation_number == u32::MAX { + Some(PreparedNotifyNewFollower::Inner { + upper_id: upper_id.clone(), + follower_id: follower_id.clone(), + already_optimizing_for_upper, + }) + } else { + Some(PreparedNotifyNewFollower::FollowerOrInner { + upper_aggregation_number, + upper_id: upper_id.clone(), + follower_id: follower_id.clone(), + already_optimizing_for_upper, + }) + } + } + } + + // Called when a inner node of the upper node has a new follower. + // It's expected that the upper node is NOT flagged as "in progress". + pub(super) fn notify_new_follower_not_in_progress< + C: AggregationContext, + >( + &mut self, + ctx: &C, + upper_id: &C::NodeRef, + follower_id: &C::NodeRef, + ) -> Option> { + let AggregationNode::Aggegating(aggregating) = self else { + unreachable!(); + }; + if aggregating.followers.add_if_entry(follower_id) { + None + } else { + start_in_progress(ctx, upper_id); + let upper_aggregation_number = aggregating.aggregation_number; + if upper_aggregation_number == u32::MAX { + Some(PreparedNotifyNewFollower::Inner { + upper_id: upper_id.clone(), + follower_id: follower_id.clone(), + already_optimizing_for_upper: false, + }) + } else { + Some(PreparedNotifyNewFollower::FollowerOrInner { + upper_aggregation_number, + upper_id: upper_id.clone(), + follower_id: follower_id.clone(), + already_optimizing_for_upper: false, + }) + } + } + } +} + +/// A prepared `notify_new_follower` operation. +pub(super) enum PreparedNotifyNewFollower { + Inner { + upper_id: C::NodeRef, + follower_id: C::NodeRef, + already_optimizing_for_upper: bool, + }, + FollowerOrInner { + upper_aggregation_number: u32, + upper_id: C::NodeRef, + follower_id: C::NodeRef, + already_optimizing_for_upper: bool, + }, +} + +impl PreparedInternalOperation for PreparedNotifyNewFollower { + type Result = usize; + fn apply(self, ctx: &C, balance_queue: &mut BalanceQueue) -> Self::Result { + match self { + PreparedNotifyNewFollower::Inner { + upper_id, + follower_id, + already_optimizing_for_upper, + } => { + let follower = ctx.node(&follower_id); + let affected_nodes = add_upper( + ctx, + balance_queue, + follower, + &follower_id, + &upper_id, + already_optimizing_for_upper, + ); + finish_in_progress_without_node(ctx, balance_queue, &upper_id); + if !already_optimizing_for_upper && affected_nodes > MAX_AFFECTED_NODES { + let follower = ctx.node(&follower_id); + let uppers = follower.uppers().iter().cloned().collect::>(); + let leaf: bool = follower.is_leaf(); + drop(follower); + if optimize_aggregation_number_for_uppers( + ctx, + balance_queue, + &follower_id, + leaf, + uppers, + ) { + return 1; + } + } + affected_nodes + } + PreparedNotifyNewFollower::FollowerOrInner { + mut upper_aggregation_number, + upper_id, + follower_id, + already_optimizing_for_upper, + } => loop { + let follower = ctx.node(&follower_id); + let follower_aggregation_number = follower.aggregation_number(); + if follower_aggregation_number < upper_aggregation_number { + let affected_nodes = add_upper( + ctx, + balance_queue, + follower, + &follower_id, + &upper_id, + already_optimizing_for_upper, + ); + finish_in_progress_without_node(ctx, balance_queue, &upper_id); + if !already_optimizing_for_upper && affected_nodes > MAX_AFFECTED_NODES { + let follower = ctx.node(&follower_id); + let uppers = follower.uppers().iter().cloned().collect::>(); + let leaf = follower.is_leaf(); + drop(follower); + if optimize_aggregation_number_for_uppers( + ctx, + balance_queue, + &follower_id, + leaf, + uppers, + ) { + return 1; + } + } + return affected_nodes; + } else { + drop(follower); + let mut upper = ctx.node(&upper_id); + let AggregationNode::Aggegating(aggregating) = &mut *upper else { + unreachable!(); + }; + upper_aggregation_number = aggregating.aggregation_number; + if upper_aggregation_number == u32::MAX { + // retry, concurrency + } else { + match follower_aggregation_number.cmp(&upper_aggregation_number) { + Ordering::Less => { + // retry, concurrency + } + Ordering::Equal => { + drop(upper); + let follower = ctx.node(&follower_id); + let follower_aggregation_number = follower.aggregation_number(); + if follower_aggregation_number == upper_aggregation_number { + increase_aggregation_number_internal( + ctx, + balance_queue, + follower, + &follower_id, + upper_aggregation_number + 1, + upper_aggregation_number + 1, + ); + // retry + } else { + // retry, concurrency + } + } + Ordering::Greater => { + upper.finish_in_progress(ctx, balance_queue, &upper_id); + return add_follower( + ctx, + balance_queue, + upper, + &upper_id, + &follower_id, + already_optimizing_for_upper, + ); + } + } + } + } + }, + } + } +} + +/// Notifies the upper node that it has a new follower. +/// Returns the number of affected nodes. +/// The upper node is expected to be flagged as "in progress". +pub fn notify_new_follower( + ctx: &C, + balance_queue: &mut BalanceQueue, + mut upper: C::Guard<'_>, + upper_id: &C::NodeRef, + follower_id: &C::NodeRef, + already_optimizing_for_upper: bool, +) -> usize { + let p = upper.notify_new_follower( + ctx, + balance_queue, + upper_id, + follower_id, + already_optimizing_for_upper, + ); + drop(upper); + p.apply(ctx, balance_queue).unwrap_or_default() +} diff --git a/crates/turbo-tasks-memory/src/aggregation/optimize.rs b/crates/turbo-tasks-memory/src/aggregation/optimize.rs new file mode 100644 index 0000000000000..fadb915180239 --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/optimize.rs @@ -0,0 +1,139 @@ +use super::{ + balance_queue::BalanceQueue, + increase::{increase_aggregation_number_internal, LEAF_NUMBER}, + AggregationContext, StackVec, +}; + +pub const MAX_UPPERS: usize = 512; + +pub const MAX_FOLLOWERS: usize = 128; + +/// Optimize the aggregation number for a node based on a list of upper nodes. +/// The goal is to reduce the number of upper nodes, so we try to find a +/// aggregation number that is higher than some of the upper nodes. +/// Returns true if the aggregation number was increased. +pub fn optimize_aggregation_number_for_uppers( + ctx: &C, + balance_queue: &mut BalanceQueue, + node_id: &C::NodeRef, + leaf: bool, + uppers: StackVec, +) -> bool { + let count = uppers.len(); + let mut root_count = 0; + let mut min = u32::MAX; + let mut max = 0; + let mut uppers_uppers = 0; + for upper_id in uppers.into_iter() { + let upper = ctx.node(&upper_id); + let aggregation_number = upper.aggregation_number(); + if aggregation_number == u32::MAX { + root_count += 1; + } else { + let upper_uppers = upper.uppers().len(); + uppers_uppers += upper_uppers; + if aggregation_number < min { + min = aggregation_number; + } + if aggregation_number > max { + max = aggregation_number; + } + } + } + if min == u32::MAX { + min = LEAF_NUMBER - 1; + } + if max < LEAF_NUMBER { + max = LEAF_NUMBER - 1; + } + let aggregation_number = (min + max) / 2 + 1; + if leaf { + increase_aggregation_number_internal( + ctx, + balance_queue, + ctx.node(node_id), + node_id, + aggregation_number, + aggregation_number, + ); + return true; + } else { + let normal_count = count - root_count; + if normal_count > 0 { + let avg_uppers_uppers = uppers_uppers / normal_count; + if count > avg_uppers_uppers && root_count * 2 < count { + increase_aggregation_number_internal( + ctx, + balance_queue, + ctx.node(node_id), + node_id, + aggregation_number, + aggregation_number, + ); + return true; + } + } + } + false +} + +/// Optimize the aggregation number for a node based on a list of followers. +/// The goal is to reduce the number of followers, so we try to find a +/// aggregation number that is higher than some of the followers. +/// Returns true if the aggregation number was increased. +pub fn optimize_aggregation_number_for_followers( + ctx: &C, + balance_queue: &mut BalanceQueue, + node_id: &C::NodeRef, + followers: StackVec, + force: bool, +) -> bool { + let count = followers.len(); + let mut root_count = 0; + let mut min = u32::MAX; + let mut max = 0; + let mut followers_followers = 0; + for follower_id in followers.into_iter() { + let follower = ctx.node(&follower_id); + let aggregation_number = follower.aggregation_number(); + if aggregation_number == u32::MAX { + root_count += 1; + } else { + let follower_followers = follower.followers().map_or(0, |f| f.len()); + followers_followers += follower_followers; + if aggregation_number < min { + min = aggregation_number; + } + if aggregation_number > max { + max = aggregation_number; + } + } + } + if min == u32::MAX { + min = LEAF_NUMBER - 1; + } + if min < LEAF_NUMBER { + min = LEAF_NUMBER - 1; + } + if max < min { + max = min; + } + let normal_count = count - root_count; + if normal_count > 0 { + let avg_followers_followers = followers_followers / normal_count; + let makes_sense = count > avg_followers_followers || force; + if makes_sense && root_count * 2 < count { + let aggregation_number = (min + max) / 2 + 1; + increase_aggregation_number_internal( + ctx, + balance_queue, + ctx.node(node_id), + node_id, + aggregation_number, + aggregation_number, + ); + return true; + } + } + false +} diff --git a/crates/turbo-tasks-memory/src/aggregation/root_query.rs b/crates/turbo-tasks-memory/src/aggregation/root_query.rs new file mode 100644 index 0000000000000..74cacffb6961a --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/root_query.rs @@ -0,0 +1,51 @@ +use std::ops::ControlFlow; + +use auto_hash_map::AutoSet; + +use super::{AggregationContext, AggregationNode, StackVec}; + +/// A query about aggregation data in a root node. +pub trait RootQuery { + type Data; + type Result; + + /// Processes the aggregated data of a root node. Can decide to stop the + /// query. + fn query(&mut self, data: &Self::Data) -> ControlFlow<()>; + /// Returns the result of the query. + fn result(self) -> Self::Result; +} + +/// Queries the root node of an aggregation tree. +pub fn query_root_info>( + ctx: &C, + mut query: Q, + node_id: C::NodeRef, +) -> Q::Result { + let mut queue = StackVec::new(); + queue.push(node_id); + let mut visited = AutoSet::new(); + while let Some(node_id) = queue.pop() { + let node = ctx.node(&node_id); + match &*node { + AggregationNode::Leaf { uppers, .. } => { + for upper_id in uppers.iter() { + if visited.insert(upper_id.clone()) { + queue.push(upper_id.clone()); + } + } + } + AggregationNode::Aggegating(aggegrating) => { + if let ControlFlow::Break(_) = query.query(&aggegrating.data) { + return query.result(); + } + for upper_id in aggegrating.uppers.iter() { + if visited.insert(upper_id.clone()) { + queue.push(upper_id.clone()); + } + } + } + } + } + query.result() +} diff --git a/crates/turbo-tasks-memory/src/aggregation/tests.rs b/crates/turbo-tasks-memory/src/aggregation/tests.rs new file mode 100644 index 0000000000000..bd6d5c2c821ad --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/tests.rs @@ -0,0 +1,1073 @@ +use std::{ + collections::HashSet, + fmt::Debug, + hash::Hash, + iter::once, + ops::{ControlFlow, Deref, DerefMut}, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, + time::Instant, +}; + +use indexmap::IndexSet; +use nohash_hasher::IsEnabled; +use parking_lot::{Mutex, MutexGuard}; +use rand::{rngs::SmallRng, Rng, SeedableRng}; +use ref_cast::RefCast; +use rstest::*; + +use self::aggregation_data::prepare_aggregation_data; +use super::{ + aggregation_data, handle_new_edge, lost_edge::handle_lost_edges, AggregationContext, + AggregationNode, AggregationNodeGuard, RootQuery, +}; +use crate::aggregation::{query_root_info, PreparedOperation, StackVec}; + +fn find_root(mut node: NodeRef) -> NodeRef { + loop { + let lock = node.0.inner.lock(); + let uppers = lock.aggregation_node.uppers(); + if uppers.is_empty() { + drop(lock); + return node; + } + let upper = uppers.iter().next().unwrap().clone(); + drop(lock); + node = upper; + } +} + +fn check_invariants<'a>( + ctx: &NodeAggregationContext<'a>, + node_ids: impl IntoIterator, +) { + let mut queue = node_ids.into_iter().collect::>(); + // print(ctx, &queue[0], true); + let mut visited = HashSet::new(); + while let Some(node_id) = queue.pop() { + assert_eq!(node_id.0.atomic.load(Ordering::SeqCst), 0); + let node = ctx.node(&node_id); + for child_id in node.children() { + if visited.insert(child_id.clone()) { + queue.push(child_id.clone()); + } + } + + let aggregation_number = node.aggregation_number(); + let node_value = node.guard.value; + let uppers = match &*node { + AggregationNode::Leaf { uppers, .. } => { + let uppers = uppers.iter().cloned().collect::>(); + drop(node); + uppers + } + AggregationNode::Aggegating(aggegrating) => { + let uppers = aggegrating.uppers.iter().cloned().collect::>(); + let followers = aggegrating + .followers + .iter() + .cloned() + .collect::>(); + drop(node); + for follower_id in followers { + let follower_aggregation_number; + let follower_uppers; + let follower_value; + { + let follower = ctx.node(&follower_id); + + follower_aggregation_number = follower.aggregation_number(); + follower_uppers = + follower.uppers().iter().cloned().collect::>(); + follower_value = follower.guard.value; + } + + // A follower should have a bigger aggregation number + let condition = follower_aggregation_number > aggregation_number + || aggregation_number == u32::MAX; + if !condition { + let msg = format!( + "follower #{} {} -> #{} {}", + node_value, + aggregation_number, + follower_value, + follower_aggregation_number + ); + print(ctx, &find_root(node_id.clone()), true); + panic!("{msg}"); + } + + // All followers should also be connected to all uppers + let missing_uppers = uppers.iter().filter(|&upper_id| { + if follower_uppers + .iter() + .any(|follower_upper_id| follower_upper_id == upper_id) + { + return false; + } + let upper = ctx.node(upper_id); + if let Some(followers) = upper.followers() { + !followers + .iter() + .any(|follower_upper_id| follower_upper_id == &follower_id) + } else { + false + } + }); + for missing_upper in missing_uppers { + let upper_value = { + let upper = ctx.node(missing_upper); + upper.guard.value + }; + let msg = format!( + "follower #{} -> #{} is not connected to upper #{}", + node_value, follower_value, upper_value, + ); + print(ctx, &find_root(node_id.clone()), true); + panic!("{msg}"); + } + + // And visit them too + if visited.insert(follower_id.clone()) { + queue.push(follower_id); + } + } + uppers + } + }; + for upper_id in uppers { + { + let upper = ctx.node(&upper_id); + let upper_aggregation_number = upper.aggregation_number(); + let condition = + upper_aggregation_number > aggregation_number || aggregation_number == u32::MAX; + if !condition { + let msg = format!( + "upper #{} {} -> #{} {}", + node_value, aggregation_number, upper.guard.value, upper_aggregation_number + ); + drop(upper); + print(ctx, &find_root(upper_id.clone()), true); + panic!("{msg}"); + } + } + if visited.insert(upper_id.clone()) { + queue.push(upper_id); + } + } + } +} + +fn print_graph( + ctx: &C, + entries: impl IntoIterator, + show_internal: bool, + name_fn: impl Fn(&C::NodeRef) -> String, +) { + let mut queue = entries.into_iter().collect::>(); + let mut visited = queue.iter().cloned().collect::>(); + while let Some(node_id) = queue.pop() { + let name = name_fn(&node_id); + let node = ctx.node(&node_id); + let n = node.aggregation_number(); + let n = if n == u32::MAX { + "♾".to_string() + } else { + n.to_string() + }; + let color = if matches!(*node, AggregationNode::Leaf { .. }) { + "gray" + } else { + "#99ff99" + }; + let children = node.children().collect::>(); + let uppers = node.uppers().iter().cloned().collect::>(); + let followers = match &*node { + AggregationNode::Aggegating(aggegrating) => aggegrating + .followers + .iter() + .cloned() + .collect::>(), + AggregationNode::Leaf { .. } => StackVec::new(), + }; + drop(node); + + if show_internal { + println!( + "\"{}\" [label=\"{}\\n{}\", style=filled, fillcolor=\"{}\"];", + name, name, n, color + ); + } else { + println!( + "\"{}\" [label=\"{}\\n{}\\n{}U {}F\", style=filled, fillcolor=\"{}\"];", + name, + name, + n, + uppers.len(), + followers.len(), + color, + ); + } + + for child_id in children { + let child_name = name_fn(&child_id); + println!("\"{}\" -> \"{}\";", name, child_name); + if visited.insert(child_id.clone()) { + queue.push(child_id); + } + } + if show_internal { + for upper_id in uppers { + let upper_name = name_fn(&upper_id); + println!( + "\"{}\" -> \"{}\" [style=dashed, color=green];", + name, upper_name + ); + if visited.insert(upper_id.clone()) { + queue.push(upper_id); + } + } + for follower_id in followers { + let follower_name = name_fn(&follower_id); + println!( + "\"{}\" -> \"{}\" [style=dashed, color=red];", + name, follower_name + ); + if visited.insert(follower_id.clone()) { + queue.push(follower_id); + } + } + } + } +} + +struct Node { + atomic: AtomicU32, + inner: Mutex, +} + +impl Node { + fn new(value: u32) -> Arc { + Arc::new(Node { + atomic: AtomicU32::new(0), + inner: Mutex::new(NodeInner { + children: Vec::new(), + aggregation_node: AggregationNode::new(), + value, + }), + }) + } + + fn new_with_children( + aggregation_context: &NodeAggregationContext, + value: u32, + children: Vec>, + ) -> Arc { + let node = Self::new(value); + for child in children { + node.add_child(aggregation_context, child); + } + node + } + + fn add_child(self: &Arc, aggregation_context: &NodeAggregationContext, child: Arc) { + self.add_child_unchecked(aggregation_context, child); + check_invariants(aggregation_context, once(find_root(NodeRef(self.clone())))); + } + + fn add_child_unchecked( + self: &Arc, + aggregation_context: &NodeAggregationContext, + child: Arc, + ) { + let mut guard = self.inner.lock(); + guard.children.push(child.clone()); + let number_of_children = guard.children.len(); + let mut guard = unsafe { NodeGuard::new(guard, self.clone()) }; + let prepared = handle_new_edge( + aggregation_context, + &mut guard, + &NodeRef(self.clone()), + &NodeRef(child), + number_of_children, + ); + drop(guard); + prepared.apply(aggregation_context); + } + + fn prepare_add_child<'c>( + self: &Arc, + aggregation_context: &'c NodeAggregationContext<'c>, + child: Arc, + ) -> impl PreparedOperation> { + let mut guard = self.inner.lock(); + guard.children.push(child.clone()); + let number_of_children = guard.children.len(); + let mut guard = unsafe { NodeGuard::new(guard, self.clone()) }; + handle_new_edge( + aggregation_context, + &mut guard, + &NodeRef(self.clone()), + &NodeRef(child), + number_of_children, + ) + } + + fn prepare_aggregation_number<'c>( + self: &Arc, + aggregation_context: &'c NodeAggregationContext<'c>, + aggregation_number: u32, + ) -> impl PreparedOperation> { + let mut guard = self.inner.lock(); + guard.aggregation_node.increase_aggregation_number( + aggregation_context, + &NodeRef(self.clone()), + aggregation_number, + ) + } + + fn remove_child( + self: &Arc, + aggregation_context: &NodeAggregationContext, + child: &Arc, + ) { + self.remove_child_unchecked(aggregation_context, child); + check_invariants(aggregation_context, once(NodeRef(self.clone()))); + } + + fn remove_child_unchecked( + self: &Arc, + aggregation_context: &NodeAggregationContext, + child: &Arc, + ) { + let mut guard = self.inner.lock(); + if let Some(idx) = guard + .children + .iter() + .position(|item| Arc::ptr_eq(item, child)) + { + guard.children.swap_remove(idx); + handle_lost_edges( + aggregation_context, + unsafe { NodeGuard::new(guard, self.clone()) }, + &NodeRef(self.clone()), + [NodeRef(child.clone())], + ); + } + } + + fn incr(self: &Arc, aggregation_context: &NodeAggregationContext) { + let mut guard = self.inner.lock(); + guard.value += 10000; + let prepared = guard + .aggregation_node + .apply_change(aggregation_context, Change { value: 10000 }); + drop(guard); + prepared.apply(aggregation_context); + check_invariants(aggregation_context, once(NodeRef(self.clone()))); + } +} + +#[derive(Copy, Clone)] +struct Change { + value: i32, +} + +impl Change { + fn is_empty(&self) -> bool { + self.value == 0 + } +} + +struct NodeInner { + children: Vec>, + aggregation_node: AggregationNode, + value: u32, +} + +struct NodeAggregationContext<'a> { + additions: AtomicU32, + #[allow(dead_code)] + something_with_lifetime: &'a u32, + add_value: bool, +} + +#[derive(Clone, RefCast)] +#[repr(transparent)] +struct NodeRef(Arc); + +impl Debug for NodeRef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "NodeRef({})", self.0.inner.lock().value) + } +} + +impl Hash for NodeRef { + fn hash(&self, state: &mut H) { + Arc::as_ptr(&self.0).hash(state); + } +} + +impl IsEnabled for NodeRef {} + +impl PartialEq for NodeRef { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl Eq for NodeRef {} + +struct NodeGuard { + guard: MutexGuard<'static, NodeInner>, + // This field is important to keep the node alive + #[allow(dead_code)] + node: Arc, +} + +impl NodeGuard { + unsafe fn new(guard: MutexGuard<'_, NodeInner>, node: Arc) -> Self { + NodeGuard { + guard: unsafe { std::mem::transmute(guard) }, + node, + } + } +} + +impl Deref for NodeGuard { + type Target = AggregationNode; + + fn deref(&self) -> &Self::Target { + &self.guard.aggregation_node + } +} + +impl DerefMut for NodeGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.guard.aggregation_node + } +} + +impl AggregationNodeGuard for NodeGuard { + type Data = Aggregated; + type NodeRef = NodeRef; + type DataChange = Change; + type ChildrenIter<'a> = impl Iterator + 'a; + + fn children(&self) -> Self::ChildrenIter<'_> { + self.guard + .children + .iter() + .map(|child| NodeRef(child.clone())) + } + + fn get_remove_change(&self) -> Option { + let change = Change { + value: -(self.guard.value as i32), + }; + if change.is_empty() { + None + } else { + Some(change) + } + } + + fn get_add_change(&self) -> Option { + let change = Change { + value: self.guard.value as i32, + }; + if change.is_empty() { + None + } else { + Some(change) + } + } + + fn get_initial_data(&self) -> Self::Data { + Aggregated { + value: self.guard.value as i32, + active: false, + } + } +} + +impl<'a> AggregationContext for NodeAggregationContext<'a> { + type Guard<'l> = NodeGuard where Self: 'l; + type Data = Aggregated; + type NodeRef = NodeRef; + type DataChange = Change; + + fn node<'b>(&'b self, reference: &Self::NodeRef) -> Self::Guard<'b> { + let r = reference.0.clone(); + let guard = reference.0.inner.lock(); + unsafe { NodeGuard::new(guard, r) } + } + + fn atomic_in_progress_counter<'l>(&self, id: &'l Self::NodeRef) -> &'l AtomicU32 + where + Self: 'l, + { + &id.0.atomic + } + + fn apply_change(&self, data: &mut Aggregated, change: &Change) -> Option { + if data.value != 0 { + self.additions.fetch_add(1, Ordering::SeqCst); + } + if self.add_value { + data.value += change.value; + Some(*change) + } else { + None + } + } + + fn data_to_add_change(&self, data: &Self::Data) -> Option { + let change = Change { value: data.value }; + if change.is_empty() { + None + } else { + Some(change) + } + } + + fn data_to_remove_change(&self, data: &Self::Data) -> Option { + let change = Change { value: -data.value }; + if change.is_empty() { + None + } else { + Some(change) + } + } +} + +#[derive(Default)] +struct ActiveQuery { + active: bool, +} + +impl RootQuery for ActiveQuery { + type Data = Aggregated; + type Result = bool; + + fn query(&mut self, data: &Self::Data) -> ControlFlow<()> { + if data.active { + self.active = true; + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + } + + fn result(self) -> Self::Result { + self.active + } +} + +#[derive(Default)] +struct Aggregated { + value: i32, + active: bool, +} + +#[test] +fn chain() { + let something_with_lifetime = 0; + let ctx = NodeAggregationContext { + additions: AtomicU32::new(0), + something_with_lifetime: &something_with_lifetime, + add_value: true, + }; + let root = Node::new(1); + let mut current = root.clone(); + for i in 2..=100 { + let node = Node::new(i); + current.add_child(&ctx, node.clone()); + current = node; + } + let leaf = Node::new(10000); + current.add_child(&ctx, leaf.clone()); + let current = NodeRef(root); + + { + let root_info = query_root_info(&ctx, ActiveQuery::default(), NodeRef(leaf.clone())); + assert!(!root_info); + } + + { + let aggregated = aggregation_data(&ctx, ¤t); + assert_eq!(aggregated.value, 15050); + } + assert_eq!(ctx.additions.load(Ordering::SeqCst), 122); + ctx.additions.store(0, Ordering::SeqCst); + check_invariants(&ctx, once(current.clone())); + + { + let root_info = query_root_info(&ctx, ActiveQuery::default(), NodeRef(leaf.clone())); + assert!(!root_info); + } + check_invariants(&ctx, once(current.clone())); + + leaf.incr(&ctx); + // The change need to propagate through 2 aggregated nodes + assert_eq!(ctx.additions.load(Ordering::SeqCst), 2); + ctx.additions.store(0, Ordering::SeqCst); + + { + let mut aggregated = aggregation_data(&ctx, ¤t); + assert_eq!(aggregated.value, 25050); + aggregated.active = true; + } + assert_eq!(ctx.additions.load(Ordering::SeqCst), 0); + ctx.additions.store(0, Ordering::SeqCst); + + { + let root_info = query_root_info(&ctx, ActiveQuery::default(), NodeRef(leaf.clone())); + assert!(root_info); + } + + let i = 101; + let current = Node::new_with_children(&ctx, i, vec![current.0]); + let current = NodeRef(current); + + { + let aggregated = aggregation_data(&ctx, ¤t); + assert_eq!(aggregated.value, 25151); + } + // This should be way less the 100 to prove that we are reusing trees + assert_eq!(ctx.additions.load(Ordering::SeqCst), 1); + ctx.additions.store(0, Ordering::SeqCst); + + leaf.incr(&ctx); + // This should be less the 20 to prove that we are reusing trees + assert_eq!(ctx.additions.load(Ordering::SeqCst), 3); + ctx.additions.store(0, Ordering::SeqCst); + + { + let root_info = query_root_info(&ctx, ActiveQuery::default(), NodeRef(leaf.clone())); + assert!(root_info); + } + + print(&ctx, ¤t, true); + check_invariants(&ctx, once(current.clone())); +} + +#[test] +fn chain_double_connected() { + let something_with_lifetime = 0; + let ctx = NodeAggregationContext { + additions: AtomicU32::new(0), + something_with_lifetime: &something_with_lifetime, + add_value: true, + }; + let root = Node::new(1); + let mut nodes = vec![root.clone()]; + let mut current = root.clone(); + let mut current2 = Node::new(2); + current.add_child(&ctx, current2.clone()); + nodes.push(current2.clone()); + for i in 3..=100 { + let node = Node::new(i); + nodes.push(node.clone()); + current.add_child(&ctx, node.clone()); + current2.add_child(&ctx, node.clone()); + current = current2; + current2 = node; + } + let current = NodeRef(root); + + { + let aggregated = aggregation_data(&ctx, ¤t); + assert_eq!(aggregated.value, 13188); + } + check_invariants(&ctx, once(current.clone())); + assert_eq!(ctx.additions.load(Ordering::SeqCst), 285); + ctx.additions.store(0, Ordering::SeqCst); + + print(&ctx, ¤t, true); + + for i in 2..nodes.len() { + nodes[i - 2].remove_child(&ctx, &nodes[i]); + nodes[i - 1].remove_child(&ctx, &nodes[i]); + } + nodes[0].remove_child(&ctx, &nodes[1]); + + { + let aggregated = aggregation_data(&ctx, ¤t); + assert_eq!(aggregated.value, 1); + } +} + +const RECT_SIZE: usize = 30; +const RECT_MULT: usize = 100; + +#[test] +fn rectangle_tree() { + let something_with_lifetime = 0; + let ctx = NodeAggregationContext { + additions: AtomicU32::new(0), + something_with_lifetime: &something_with_lifetime, + add_value: false, + }; + let mut nodes: Vec>> = Vec::new(); + let mut extra_nodes = Vec::new(); + for y in 0..RECT_SIZE { + let mut line: Vec> = Vec::new(); + for x in 0..RECT_SIZE { + let mut parents = Vec::new(); + if x > 0 { + parents.push(line[x - 1].clone()); + } + if y > 0 { + parents.push(nodes[y - 1][x].clone()); + } + let value = (x + y * RECT_MULT) as u32; + let node = Node::new(value); + if x == 0 || y == 0 { + let extra_node = Node::new(value + 100000); + prepare_aggregation_data(&ctx, &NodeRef(extra_node.clone())); + extra_node.add_child(&ctx, node.clone()); + extra_nodes.push(extra_node); + prepare_aggregation_data(&ctx, &NodeRef(node.clone())); + } + for parent in parents { + parent.add_child_unchecked(&ctx, node.clone()); + } + if x == 0 || y == 0 { + prepare_aggregation_data(&ctx, &NodeRef(node.clone())); + } + line.push(node); + } + nodes.push(line); + } + + check_invariants(&ctx, extra_nodes.iter().cloned().map(NodeRef)); + + let root = NodeRef(extra_nodes[0].clone()); + print(&ctx, &root, false); +} + +#[rstest] +#[case::many_roots_initial(100000, 0, 2, 1)] +#[case::many_roots_later(1, 100000, 2, 1)] +#[case::many_roots_later2(0, 100000, 2, 1)] +#[case::many_roots(50000, 50000, 2, 1)] +#[case::many_children(2, 0, 100000, 1)] +#[case::many_roots_and_children(5000, 5000, 10000, 1)] +#[case::many_roots_and_subgraph(5000, 5000, 100, 2)] +#[case::large_subgraph_a(9, 1, 10, 5)] +#[case::large_subgraph_b(5, 5, 10, 5)] +#[case::large_subgraph_c(1, 9, 10, 5)] +#[case::large_subgraph_d(6, 0, 10, 5)] +#[case::large_subgraph_e(0, 10, 10, 5)] +#[case::many_roots_large_subgraph(5000, 5000, 10, 5)] +fn performance( + #[case] initial_root_count: u32, + #[case] additional_root_count: u32, + #[case] children_count: u32, + #[case] children_layers_count: u32, +) { + fn print_aggregation_numbers(node: Arc) { + print!("Aggregation numbers "); + let mut current = node.clone(); + loop { + let guard = current.inner.lock(); + let n = guard.aggregation_node.aggregation_number(); + let f = guard.aggregation_node.followers().map_or(0, |f| f.len()); + let u = guard.aggregation_node.uppers().len(); + print!(" -> {} [{}U {}F]", n, u, f); + if guard.children.is_empty() { + break; + } + let child = guard.children[guard.children.len() / 2].clone(); + drop(guard); + current = child; + } + println!(); + } + + let something_with_lifetime = 0; + let ctx = NodeAggregationContext { + additions: AtomicU32::new(0), + something_with_lifetime: &something_with_lifetime, + add_value: false, + }; + let mut roots: Vec> = Vec::new(); + let inner_node = Node::new(0); + // Setup + for i in 0..initial_root_count { + let node = Node::new(2 + i); + roots.push(node.clone()); + aggregation_data(&ctx, &NodeRef(node.clone())).active = true; + node.add_child_unchecked(&ctx, inner_node.clone()); + } + let start = Instant::now(); + let mut children = vec![inner_node.clone()]; + for j in 0..children_layers_count { + let mut new_children = Vec::new(); + for child in children { + for i in 0..children_count { + let node = Node::new(1000000 * (j + 1) + i); + new_children.push(node.clone()); + child.add_child_unchecked(&ctx, node.clone()); + } + } + children = new_children; + } + println!("Setup children: {:?}", start.elapsed()); + + print_aggregation_numbers(inner_node.clone()); + + let start = Instant::now(); + for i in 0..additional_root_count { + let node = Node::new(2 + i); + roots.push(node.clone()); + aggregation_data(&ctx, &NodeRef(node.clone())).active = true; + node.add_child_unchecked(&ctx, inner_node.clone()); + } + println!("Setup additional roots: {:?}", start.elapsed()); + + print_aggregation_numbers(inner_node.clone()); + + // Add another root + let start = Instant::now(); + { + let node = Node::new(1); + roots.push(node.clone()); + aggregation_data(&ctx, &NodeRef(node.clone())).active = true; + node.add_child_unchecked(&ctx, inner_node.clone()); + } + let root_duration = start.elapsed(); + println!("Root: {:?}", root_duration); + + // Add another child + let start = Instant::now(); + { + let node = Node::new(999999); + inner_node.add_child_unchecked(&ctx, node.clone()); + } + let child_duration = start.elapsed(); + println!("Child: {:?}", child_duration); + + print_aggregation_numbers(inner_node.clone()); + + assert!(root_duration.as_micros() < 10000); + assert!(child_duration.as_micros() < 10000); + + // check_invariants(&ctx, roots.iter().cloned().map(NodeRef)); +} + +#[test] +fn many_children() { + let something_with_lifetime = 0; + let ctx = NodeAggregationContext { + additions: AtomicU32::new(0), + something_with_lifetime: &something_with_lifetime, + add_value: false, + }; + let mut roots: Vec> = Vec::new(); + let mut children: Vec> = Vec::new(); + const CHILDREN: u32 = 100000; + const ROOTS: u32 = 3; + let inner_node = Node::new(0); + let start = Instant::now(); + for i in 0..ROOTS { + let node = Node::new(10000 + i); + roots.push(node.clone()); + aggregation_data(&ctx, &NodeRef(node.clone())).active = true; + node.add_child_unchecked(&ctx, inner_node.clone()); + } + println!("Roots: {:?}", start.elapsed()); + let start = Instant::now(); + for i in 0..CHILDREN { + let node = Node::new(20000 + i); + children.push(node.clone()); + inner_node.add_child_unchecked(&ctx, node.clone()); + } + println!("Children: {:?}", start.elapsed()); + let start = Instant::now(); + for i in 0..CHILDREN { + let node = Node::new(40000 + i); + children.push(node.clone()); + inner_node.add_child_unchecked(&ctx, node.clone()); + } + let children_duration = start.elapsed(); + println!("Children: {:?}", children_duration); + let mut number_of_slow_children = 0; + for j in 0..10 { + let start = Instant::now(); + for i in 0..CHILDREN { + let node = Node::new(50000 + j * 10000 + i); + children.push(node.clone()); + inner_node.add_child_unchecked(&ctx, node.clone()); + } + let dur = start.elapsed(); + println!("Children: {:?}", dur); + if dur > children_duration * 2 { + number_of_slow_children += 1; + } + } + + let start = Instant::now(); + for i in 0..ROOTS { + let node = Node::new(30000 + i); + roots.push(node.clone()); + aggregation_data(&ctx, &NodeRef(node.clone())).active = true; + node.add_child_unchecked(&ctx, inner_node.clone()); + } + println!("Roots: {:?}", start.elapsed()); + + // Technically it should always be 0, but the performance of the environment + // might vary so we accept a few slow children + assert!(number_of_slow_children < 3); + + check_invariants(&ctx, roots.iter().cloned().map(NodeRef)); + + // let root = NodeRef(roots[0].clone()); + // print(&ctx, &root, false); +} + +#[test] +fn concurrent_modification() { + let something_with_lifetime = 0; + let ctx = NodeAggregationContext { + additions: AtomicU32::new(0), + something_with_lifetime: &something_with_lifetime, + add_value: true, + }; + let root1 = Node::new(1); + let root2 = Node::new(2); + let helper = Node::new(3); + let inner_node = Node::new(10); + let outer_node1 = Node::new(11); + let outer_node2 = Node::new(12); + let outer_node3 = Node::new(13); + let outer_node4 = Node::new(14); + inner_node.add_child(&ctx, outer_node1.clone()); + inner_node.add_child(&ctx, outer_node2.clone()); + root2.add_child(&ctx, helper.clone()); + outer_node1.prepare_aggregation_number(&ctx, 7).apply(&ctx); + outer_node3.prepare_aggregation_number(&ctx, 7).apply(&ctx); + root1.prepare_aggregation_number(&ctx, 8).apply(&ctx); + root2.prepare_aggregation_number(&ctx, 4).apply(&ctx); + helper.prepare_aggregation_number(&ctx, 3).apply(&ctx); + + let add_job1 = root1.prepare_add_child(&ctx, inner_node.clone()); + let add_job2 = inner_node.prepare_add_child(&ctx, outer_node3.clone()); + let add_job3 = inner_node.prepare_add_child(&ctx, outer_node4.clone()); + let add_job4 = helper.prepare_add_child(&ctx, inner_node.clone()); + + add_job4.apply(&ctx); + print_all(&ctx, [root1.clone(), root2.clone()].map(NodeRef), true); + add_job3.apply(&ctx); + print_all(&ctx, [root1.clone(), root2.clone()].map(NodeRef), true); + add_job2.apply(&ctx); + print_all(&ctx, [root1.clone(), root2.clone()].map(NodeRef), true); + add_job1.apply(&ctx); + + print_all(&ctx, [root1.clone(), root2.clone()].map(NodeRef), true); + + check_invariants(&ctx, [root1, root2].map(NodeRef)); +} + +#[test] +fn fuzzy_new() { + for size in [10, 50, 100, 200, 1000] { + for _ in 0..100 { + let seed = rand::random(); + println!("Seed {} Size {}", seed, size); + fuzzy(seed, size); + } + } +} + +#[rstest] +#[case::a(4059591975, 10)] +#[case::b(603692396, 100)] +#[case::c(3317876847, 10)] +#[case::d(4012518846, 50)] +fn fuzzy(#[case] seed: u32, #[case] count: u32) { + let something_with_lifetime = 0; + let ctx = NodeAggregationContext { + additions: AtomicU32::new(0), + something_with_lifetime: &something_with_lifetime, + add_value: true, + }; + + let mut seed_buffer = [0; 32]; + seed_buffer[0..4].copy_from_slice(&seed.to_be_bytes()); + let mut r = SmallRng::from_seed(seed_buffer); + let mut nodes = Vec::new(); + for i in 0..count { + nodes.push(Node::new(i)); + } + prepare_aggregation_data(&ctx, &NodeRef(nodes[0].clone())); + + let mut edges = IndexSet::new(); + + for _ in 0..1000 { + match r.gen_range(0..=2) { + 0 | 1 => { + // if x == 47 { + // print_all(&ctx, nodes.iter().cloned().map(NodeRef), true); + // } + // add edge + let parent = r.gen_range(0..nodes.len() - 1); + let child = r.gen_range(parent + 1..nodes.len()); + // println!("add edge {} -> {}", parent, child); + if edges.insert((parent, child)) { + nodes[parent].add_child(&ctx, nodes[child].clone()); + } + } + 2 => { + // remove edge + if edges.is_empty() { + continue; + } + let i = r.gen_range(0..edges.len()); + let (parent, child) = edges.swap_remove_index(i).unwrap(); + // println!("remove edge {} -> {}", parent, child); + nodes[parent].remove_child(&ctx, &nodes[child]); + } + _ => unreachable!(), + } + } + + for (parent, child) in edges { + nodes[parent].remove_child(&ctx, &nodes[child]); + } + + assert_eq!(aggregation_data(&ctx, &NodeRef(nodes[0].clone())).value, 0); + + check_invariants(&ctx, nodes.iter().cloned().map(NodeRef)); + + for node in nodes { + let lock = node.inner.lock(); + if let AggregationNode::Aggegating(a) = &lock.aggregation_node { + assert_eq!(a.data.value, lock.value as i32); + } + } +} + +fn print(aggregation_context: &NodeAggregationContext<'_>, root: &NodeRef, show_internal: bool) { + print_all(aggregation_context, once(root.clone()), show_internal); +} + +fn print_all( + aggregation_context: &NodeAggregationContext<'_>, + nodes: impl IntoIterator, + show_internal: bool, +) { + println!("digraph {{"); + print_graph(aggregation_context, nodes, show_internal, |item| { + let lock = item.0.inner.lock(); + if let AggregationNode::Aggegating(a) = &lock.aggregation_node { + format!("#{} [{}]", lock.value, a.data.value) + } else { + format!("#{}", lock.value) + } + }); + println!("\n}}"); +} diff --git a/crates/turbo-tasks-memory/src/aggregation/uppers.rs b/crates/turbo-tasks-memory/src/aggregation/uppers.rs new file mode 100644 index 0000000000000..e131962cd280b --- /dev/null +++ b/crates/turbo-tasks-memory/src/aggregation/uppers.rs @@ -0,0 +1,256 @@ +use super::{ + balance_queue::BalanceQueue, + in_progress::start_in_progress_count, + optimize::{optimize_aggregation_number_for_uppers, MAX_UPPERS}, + AggegatingNode, AggregationContext, AggregationNode, AggregationNodeGuard, + PreparedInternalOperation, PreparedOperation, StackVec, +}; +use crate::count_hash_set::RemovePositiveCountResult; + +/// Adds an upper node to a node. Returns the number of affected nodes by this +/// operation. This will also propagate the followers to the new upper node. +pub fn add_upper( + ctx: &C, + balance_queue: &mut BalanceQueue, + node: C::Guard<'_>, + node_id: &C::NodeRef, + upper_id: &C::NodeRef, + already_optimizing_for_upper: bool, +) -> usize { + add_upper_count( + ctx, + balance_queue, + node, + node_id, + upper_id, + 1, + already_optimizing_for_upper, + ) + .affected_nodes +} + +pub struct AddUpperCountResult { + pub new_count: isize, + pub affected_nodes: usize, +} + +/// Adds an upper node to a node with a given count. Returns the new count of +/// the upper node and the number of affected nodes by this operation. This will +/// also propagate the followers to the new upper node. +pub fn add_upper_count( + ctx: &C, + balance_queue: &mut BalanceQueue, + mut node: C::Guard<'_>, + node_id: &C::NodeRef, + upper_id: &C::NodeRef, + count: usize, + already_optimizing_for_upper: bool, +) -> AddUpperCountResult { + // TODO add_clonable_count could return the current count for better performance + let (added, count) = match &mut *node { + AggregationNode::Leaf { uppers, .. } => { + if uppers.add_clonable_count(upper_id, count) { + let count = uppers.get_count(upper_id); + (true, count) + } else { + (false, uppers.get_count(upper_id)) + } + } + AggregationNode::Aggegating(aggegating) => { + let AggegatingNode { ref mut uppers, .. } = **aggegating; + if uppers.add_clonable_count(upper_id, count) { + let count = uppers.get_count(upper_id); + (true, count) + } else { + (false, uppers.get_count(upper_id)) + } + } + }; + let mut affected_nodes = 0; + if added { + affected_nodes = on_added( + ctx, + balance_queue, + node, + node_id, + upper_id, + already_optimizing_for_upper, + ); + } else { + drop(node); + } + AddUpperCountResult { + new_count: count, + affected_nodes, + } +} + +/// Called when an upper node was added to a node. This will propagate the +/// followers to the new upper node. +pub fn on_added( + ctx: &C, + balance_queue: &mut BalanceQueue, + mut node: C::Guard<'_>, + node_id: &C::NodeRef, + upper_id: &C::NodeRef, + already_optimizing_for_upper: bool, +) -> usize { + let uppers = node.uppers(); + let uppers_len = uppers.len(); + let optimize = (!already_optimizing_for_upper + && uppers_len > MAX_UPPERS + && (uppers_len - MAX_UPPERS).count_ones() == 1) + .then(|| (true, uppers.iter().cloned().collect::>())); + let (add_change, followers) = match &mut *node { + AggregationNode::Leaf { .. } => { + let add_change = node.get_add_change(); + let children = node.children().collect::>(); + start_in_progress_count(ctx, upper_id, children.len() as u32); + drop(node); + (add_change, children) + } + AggregationNode::Aggegating(aggegating) => { + let AggegatingNode { ref followers, .. } = **aggegating; + let add_change = ctx.data_to_add_change(&aggegating.data); + let followers = followers.iter().cloned().collect::>(); + start_in_progress_count(ctx, upper_id, followers.len() as u32); + drop(node); + + (add_change, followers) + } + }; + + let mut optimizing = false; + + // This heuristic ensures that we don’t have too many upper edges, which would + // degrade update performance + if let Some((leaf, uppers)) = optimize { + optimizing = + optimize_aggregation_number_for_uppers(ctx, balance_queue, node_id, leaf, uppers); + } + + let mut affected_nodes = 0; + + // Make sure to propagate the change to the upper node + let mut upper = ctx.node(upper_id); + let add_prepared = add_change.and_then(|add_change| upper.apply_change(ctx, add_change)); + affected_nodes += followers.len(); + let prepared = followers + .into_iter() + .filter_map(|child_id| { + upper.notify_new_follower(ctx, balance_queue, upper_id, &child_id, optimizing) + }) + .collect::>(); + drop(upper); + add_prepared.apply(ctx); + for prepared in prepared { + affected_nodes += prepared.apply(ctx, balance_queue); + } + + affected_nodes +} + +/// Removes an upper node from a node with a count. +pub fn remove_upper_count( + ctx: &C, + balance_queue: &mut BalanceQueue, + mut node: C::Guard<'_>, + upper_id: &C::NodeRef, + count: usize, +) { + let removed = match &mut *node { + AggregationNode::Leaf { uppers, .. } => uppers.remove_clonable_count(upper_id, count), + AggregationNode::Aggegating(aggegating) => { + let AggegatingNode { ref mut uppers, .. } = **aggegating; + uppers.remove_clonable_count(upper_id, count) + } + }; + if removed { + on_removed(ctx, balance_queue, node, upper_id); + } +} + +pub struct RemovePositiveUpperCountResult { + pub removed_count: usize, + pub remaining_count: isize, +} + +/// Removes a positive count of an upper node from a node. +/// Returns the removed count and the remaining count of the upper node. +/// This will also propagate the followers to the removed upper node. +pub fn remove_positive_upper_count( + ctx: &C, + balance_queue: &mut BalanceQueue, + mut node: C::Guard<'_>, + upper_id: &C::NodeRef, + count: usize, +) -> RemovePositiveUpperCountResult { + let RemovePositiveCountResult { + removed, + removed_count, + count, + } = match &mut *node { + AggregationNode::Leaf { uppers, .. } => { + uppers.remove_positive_clonable_count(upper_id, count) + } + AggregationNode::Aggegating(aggegating) => { + let AggegatingNode { ref mut uppers, .. } = **aggegating; + uppers.remove_positive_clonable_count(upper_id, count) + } + }; + if removed { + on_removed(ctx, balance_queue, node, upper_id); + } + RemovePositiveUpperCountResult { + removed_count, + remaining_count: count, + } +} + +/// Called when an upper node was removed from a node. This will propagate the +/// followers to the removed upper node. +pub fn on_removed( + ctx: &C, + balance_queue: &mut BalanceQueue, + node: C::Guard<'_>, + upper_id: &C::NodeRef, +) { + match &*node { + AggregationNode::Leaf { .. } => { + let remove_change = node.get_remove_change(); + let children = node.children().collect::>(); + drop(node); + let mut upper = ctx.node(upper_id); + let remove_prepared = + remove_change.and_then(|remove_change| upper.apply_change(ctx, remove_change)); + start_in_progress_count(ctx, upper_id, children.len() as u32); + let prepared = children + .into_iter() + .map(|child_id| upper.notify_lost_follower(ctx, balance_queue, upper_id, &child_id)) + .collect::>(); + drop(upper); + remove_prepared.apply(ctx); + prepared.apply(ctx, balance_queue); + } + AggregationNode::Aggegating(aggegating) => { + let remove_change = ctx.data_to_remove_change(&aggegating.data); + let followers = aggegating + .followers + .iter() + .cloned() + .collect::>(); + drop(node); + let mut upper = ctx.node(upper_id); + let remove_prepared = + remove_change.and_then(|remove_change| upper.apply_change(ctx, remove_change)); + start_in_progress_count(ctx, upper_id, followers.len() as u32); + let prepared = followers + .into_iter() + .map(|child_id| upper.notify_lost_follower(ctx, balance_queue, upper_id, &child_id)) + .collect::>(); + drop(upper); + remove_prepared.apply(ctx); + prepared.apply(ctx, balance_queue); + } + } +} diff --git a/crates/turbo-tasks-memory/src/aggregation_tree/bottom_connection.rs b/crates/turbo-tasks-memory/src/aggregation_tree/bottom_connection.rs deleted file mode 100644 index 2efd9e71cf144..0000000000000 --- a/crates/turbo-tasks-memory/src/aggregation_tree/bottom_connection.rs +++ /dev/null @@ -1,346 +0,0 @@ -use std::{hash::Hash, ops::ControlFlow, sync::Arc}; - -use auto_hash_map::{map::RawEntry, AutoMap}; -use nohash_hasher::{BuildNoHashHasher, IsEnabled}; - -use super::{ - bottom_tree::BottomTree, - inner_refs::{BottomRef, ChildLocation}, - AggregationContext, StackVec, -}; - -struct BottomRefInfo { - count: isize, - distance: u8, -} - -/// A map that stores references to bottom trees which a specific distance. It -/// stores the minimum distance added to the map. -/// -/// This is used to store uppers of leafs or smaller bottom trees with the -/// current distance. The distance is imporant to keep the correct connectivity. -#[derive(Default)] -pub struct DistanceCountMap { - map: AutoMap>, -} - -impl DistanceCountMap { - pub fn new() -> Self { - Self { - map: AutoMap::with_hasher(), - } - } - - pub fn is_unset(&self) -> bool { - self.map.is_empty() - } - - pub fn iter(&self) -> impl Iterator { - self.map - .iter() - .filter(|(_, info)| info.count > 0) - .map(|(item, &BottomRefInfo { distance, .. })| (item, distance)) - } - - pub fn add_clonable(&mut self, item: &T, distance: u8) -> bool { - match self.map.raw_entry_mut(item) { - RawEntry::Occupied(mut e) => { - let info = e.get_mut(); - info.count += 1; - match info.count.cmp(&0) { - std::cmp::Ordering::Equal => { - e.remove(); - } - std::cmp::Ordering::Greater => { - if distance < info.distance { - info.distance = distance; - } - } - std::cmp::Ordering::Less => { - // We only track that for negative count tracking and no - // need to update the distance, it would reset anyway - // once we reach 0. - } - } - false - } - RawEntry::Vacant(e) => { - e.insert(item.clone(), BottomRefInfo { count: 1, distance }); - true - } - } - } - - pub fn remove_clonable(&mut self, item: &T) -> bool { - match self.map.raw_entry_mut(item) { - RawEntry::Occupied(mut e) => { - let info = e.get_mut(); - info.count -= 1; - if info.count == 0 { - e.remove(); - true - } else { - false - } - } - RawEntry::Vacant(e) => { - e.insert( - item.clone(), - BottomRefInfo { - count: -1, - distance: 0, - }, - ); - false - } - } - } - - pub fn into_counts(self) -> impl Iterator { - self.map.into_iter().map(|(item, info)| (item, info.count)) - } - - pub fn len(&self) -> usize { - self.map.len() - } -} - -/// Connection to upper bottom trees. It has two modes: A single bottom tree, -/// where the current left/smaller bottom tree is the left-most child. Or -/// multiple bottom trees, where the current left/smaller bottom tree is an -/// inner child (not left-most). -pub enum BottomConnection { - Left(Arc>), - Inner(DistanceCountMap>), -} - -impl BottomConnection { - pub fn new() -> Self { - Self::Inner(DistanceCountMap::new()) - } - - pub fn is_unset(&self) -> bool { - match self { - Self::Left(_) => false, - Self::Inner(list) => list.is_unset(), - } - } - - pub fn as_cloned_uppers(&self) -> BottomUppers { - match self { - Self::Left(upper) => BottomUppers::Left(upper.clone()), - Self::Inner(upper) => BottomUppers::Inner( - upper - .iter() - .map(|(item, distance)| (item.clone(), distance)) - .collect(), - ), - } - } - - #[must_use] - pub fn set_left_upper( - &mut self, - upper: &Arc>, - ) -> DistanceCountMap> { - match std::mem::replace(self, BottomConnection::Left(upper.clone())) { - BottomConnection::Left(_) => unreachable!("Can't have two left children"), - BottomConnection::Inner(old_inner) => old_inner, - } - } - - pub fn unset_left_upper(&mut self, upper: &Arc>) { - match std::mem::replace(self, BottomConnection::Inner(DistanceCountMap::new())) { - BottomConnection::Left(old_upper) => { - debug_assert!(Arc::ptr_eq(&old_upper, upper)); - } - BottomConnection::Inner(_) => unreachable!("Must that a left child"), - } - } -} - -impl BottomConnection { - pub fn child_change>( - &self, - aggregation_context: &C, - change: &C::ItemChange, - ) { - match self { - BottomConnection::Left(upper) => { - upper.child_change(aggregation_context, change); - } - BottomConnection::Inner(list) => { - for (BottomRef { upper }, _) in list.iter() { - upper.child_change(aggregation_context, change); - } - } - } - } - - pub fn get_root_info>( - &self, - aggregation_context: &C, - root_info_type: &C::RootInfoType, - mut result: C::RootInfo, - ) -> C::RootInfo { - match &self { - BottomConnection::Left(upper) => { - let info = upper.get_root_info(aggregation_context, root_info_type); - if aggregation_context.merge_root_info(&mut result, info) == ControlFlow::Break(()) - { - return result; - } - } - BottomConnection::Inner(list) => { - for (BottomRef { upper }, _) in list.iter() { - let info = upper.get_root_info(aggregation_context, root_info_type); - if aggregation_context.merge_root_info(&mut result, info) - == ControlFlow::Break(()) - { - return result; - } - } - } - } - result - } -} - -pub enum BottomUppers { - Left(Arc>), - Inner(StackVec<(BottomRef, u8)>), -} - -impl BottomUppers { - pub fn add_children_of_child<'a, C: AggregationContext>( - &self, - aggregation_context: &C, - children: impl IntoIterator + Clone, - ) where - I: 'a, - { - match self { - BottomUppers::Left(upper) => { - upper.add_children_of_child(aggregation_context, ChildLocation::Left, children, 0); - } - BottomUppers::Inner(list) => { - for &(BottomRef { ref upper }, nesting_level) in list { - upper.add_children_of_child( - aggregation_context, - ChildLocation::Inner, - children.clone(), - nesting_level + 1, - ); - } - } - } - } - - pub fn add_child_of_child>( - &self, - aggregation_context: &C, - child_of_child: &I, - ) { - match self { - BottomUppers::Left(upper) => { - upper.add_child_of_child( - aggregation_context, - ChildLocation::Left, - child_of_child, - 0, - ); - } - BottomUppers::Inner(list) => { - for &(BottomRef { ref upper }, nesting_level) in list.iter() { - upper.add_child_of_child( - aggregation_context, - ChildLocation::Inner, - child_of_child, - nesting_level + 1, - ); - } - } - } - } - - pub fn remove_child_of_child>( - &self, - aggregation_context: &C, - child_of_child: &I, - ) { - match self { - BottomUppers::Left(upper) => { - upper.remove_child_of_child(aggregation_context, child_of_child); - } - BottomUppers::Inner(list) => { - for (BottomRef { upper }, _) in list { - upper.remove_child_of_child(aggregation_context, child_of_child); - } - } - } - } - - pub fn remove_children_of_child<'a, C: AggregationContext>( - &self, - aggregation_context: &C, - children: impl IntoIterator + Clone, - ) where - I: 'a, - { - match self { - BottomUppers::Left(upper) => { - upper.remove_children_of_child(aggregation_context, children); - } - BottomUppers::Inner(list) => { - for (BottomRef { upper }, _) in list { - upper.remove_children_of_child(aggregation_context, children.clone()); - } - } - } - } - - pub fn child_change>( - &self, - aggregation_context: &C, - change: &C::ItemChange, - ) { - match self { - BottomUppers::Left(upper) => { - upper.child_change(aggregation_context, change); - } - BottomUppers::Inner(list) => { - for (BottomRef { upper }, _) in list { - upper.child_change(aggregation_context, change); - } - } - } - } - - pub fn get_root_info>( - &self, - aggregation_context: &C, - root_info_type: &C::RootInfoType, - mut result: C::RootInfo, - ) -> C::RootInfo { - match &self { - BottomUppers::Left(upper) => { - let info = upper.get_root_info(aggregation_context, root_info_type); - if aggregation_context.merge_root_info(&mut result, info) == ControlFlow::Break(()) - { - return result; - } - } - BottomUppers::Inner(list) => { - for (BottomRef { upper }, _) in list.iter() { - let info = upper.get_root_info(aggregation_context, root_info_type); - if aggregation_context.merge_root_info(&mut result, info) - == ControlFlow::Break(()) - { - return result; - } - } - } - } - result - } -} diff --git a/crates/turbo-tasks-memory/src/aggregation_tree/bottom_tree.rs b/crates/turbo-tasks-memory/src/aggregation_tree/bottom_tree.rs deleted file mode 100644 index 35b9472c7f4a9..0000000000000 --- a/crates/turbo-tasks-memory/src/aggregation_tree/bottom_tree.rs +++ /dev/null @@ -1,730 +0,0 @@ -use std::{hash::Hash, ops::ControlFlow, sync::Arc}; - -use nohash_hasher::{BuildNoHashHasher, IsEnabled}; -use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; -use ref_cast::RefCast; - -use super::{ - bottom_connection::BottomConnection, - inner_refs::{BottomRef, ChildLocation, TopRef}, - leaf::{ - add_inner_upper_to_item, bottom_tree, remove_inner_upper_from_item, - remove_left_upper_from_item, - }, - top_tree::TopTree, - AggregationContext, StackVec, CHILDREN_INNER_THRESHOLD, CONNECTIVITY_LIMIT, -}; -use crate::count_hash_set::{CountHashSet, RemoveIfEntryResult}; - -/// The bottom half of the aggregation tree. It aggregates items up the a -/// certain connectivity depending on the "height". Every level of the tree -/// aggregates the previous level. -pub struct BottomTree { - height: u8, - item: I, - state: RwLock>, -} - -pub struct BottomTreeState { - data: T, - bottom_upper: BottomConnection, - top_upper: CountHashSet, BuildNoHashHasher>>, - // TODO can this become negative? - following: CountHashSet>, -} - -impl BottomTree { - pub fn new(item: I, height: u8) -> Self { - Self { - height, - item, - state: RwLock::new(BottomTreeState { - data: T::default(), - bottom_upper: BottomConnection::new(), - top_upper: CountHashSet::new(), - following: CountHashSet::new(), - }), - } - } -} - -impl BottomTree { - pub fn add_children_of_child<'a, C: AggregationContext>( - self: &Arc, - aggregation_context: &C, - child_location: ChildLocation, - children: impl IntoIterator, - nesting_level: u8, - ) where - I: 'a, - { - match child_location { - ChildLocation::Left => { - // the left child has new children - // this means it's a inner child of this node - // We always want to aggregate over at least connectivity 1 - self.add_children_of_child_inner(aggregation_context, children, nesting_level); - } - ChildLocation::Inner => { - // the inner child has new children - // this means white children are inner children of this node - // and blue children need to propagate up - let mut children = children.into_iter().collect(); - if nesting_level > CONNECTIVITY_LIMIT { - self.add_children_of_child_following(aggregation_context, children); - return; - } - - self.add_children_of_child_if_following(&mut children); - self.add_children_of_child_inner(aggregation_context, children, nesting_level); - } - } - } - - fn add_children_of_child_if_following(&self, children: &mut StackVec<&I>) { - let mut state = self.state.write(); - children.retain(|&mut child| !state.following.add_if_entry(child)); - } - - fn add_children_of_child_following>( - self: &Arc, - aggregation_context: &C, - mut children: StackVec<&I>, - ) { - let mut state = self.state.write(); - children.retain(|&mut child| state.following.add_clonable(child)); - if children.is_empty() { - return; - } - let buttom_uppers = state.bottom_upper.as_cloned_uppers(); - let top_upper = state.top_upper.iter().cloned().collect::>(); - drop(state); - for TopRef { upper } in top_upper { - upper.add_children_of_child(aggregation_context, children.iter().copied()); - } - buttom_uppers.add_children_of_child(aggregation_context, children.iter().copied()); - } - - fn add_children_of_child_inner<'a, C: AggregationContext>( - self: &Arc, - aggregation_context: &C, - children: impl IntoIterator, - nesting_level: u8, - ) where - I: 'a, - { - let mut following = StackVec::default(); - if self.height == 0 { - for child in children { - let can_be_inner = - add_inner_upper_to_item(aggregation_context, child, self, nesting_level); - if !can_be_inner { - following.push(child); - } - } - } else { - for child in children { - let can_be_inner = bottom_tree(aggregation_context, child, self.height - 1) - .add_inner_bottom_tree_upper(aggregation_context, self, nesting_level); - if !can_be_inner { - following.push(child); - } - } - } - if !following.is_empty() { - self.add_children_of_child_following(aggregation_context, following); - } - } - - pub fn add_child_of_child>( - self: &Arc, - aggregation_context: &C, - child_location: ChildLocation, - child_of_child: &I, - nesting_level: u8, - ) { - debug_assert!(child_of_child != &self.item); - match child_location { - ChildLocation::Left => { - // the left child has a new child - // this means it's a inner child of this node - // We always want to aggregate over at least connectivity 1 - self.add_child_of_child_inner(aggregation_context, child_of_child, nesting_level); - } - ChildLocation::Inner => { - if nesting_level <= CONNECTIVITY_LIMIT { - // the inner child has a new child - // but it's not a blue node and we are not too deep - // this means it's a inner child of this node - // if it's not already a following child - if !self.add_child_of_child_if_following(child_of_child) { - self.add_child_of_child_inner( - aggregation_context, - child_of_child, - nesting_level, - ); - } - } else { - // the inner child has a new child - // this means we need to propagate the change up - // and store them in our own list - self.add_child_of_child_following(aggregation_context, child_of_child); - } - } - } - } - - fn add_child_of_child_if_following(&self, child_of_child: &I) -> bool { - let mut state = self.state.write(); - state.following.add_if_entry(child_of_child) - } - - fn add_child_of_child_following>( - self: &Arc, - aggregation_context: &C, - child_of_child: &I, - ) { - let mut state = self.state.write(); - if !state.following.add_clonable(child_of_child) { - // Already connect, nothing more to do - return; - } - - propagate_new_following_to_uppers(state, aggregation_context, child_of_child); - } - - fn add_child_of_child_inner>( - self: &Arc, - aggregation_context: &C, - child_of_child: &I, - nesting_level: u8, - ) { - let can_be_inner = if self.height == 0 { - add_inner_upper_to_item(aggregation_context, child_of_child, self, nesting_level) - } else { - bottom_tree(aggregation_context, child_of_child, self.height - 1) - .add_inner_bottom_tree_upper(aggregation_context, self, nesting_level) - }; - if !can_be_inner { - self.add_child_of_child_following(aggregation_context, child_of_child); - } - } - - pub fn remove_child_of_child>( - self: &Arc, - aggregation_context: &C, - child_of_child: &I, - ) { - if !self.remove_child_of_child_if_following(aggregation_context, child_of_child) { - self.remove_child_of_child_inner(aggregation_context, child_of_child); - } - } - - pub fn remove_children_of_child<'a, C: AggregationContext>( - self: &Arc, - aggregation_context: &C, - children: impl IntoIterator, - ) where - I: 'a, - { - let mut children = children.into_iter().collect(); - self.remove_children_of_child_if_following(aggregation_context, &mut children); - self.remove_children_of_child_inner(aggregation_context, children); - } - - fn remove_child_of_child_if_following>( - self: &Arc, - aggregation_context: &C, - child_of_child: &I, - ) -> bool { - let mut state = self.state.write(); - match state.following.remove_if_entry(child_of_child) { - RemoveIfEntryResult::PartiallyRemoved => return true, - RemoveIfEntryResult::NotPresent => return false, - RemoveIfEntryResult::Removed => {} - } - propagate_lost_following_to_uppers(state, aggregation_context, child_of_child); - true - } - - fn remove_children_of_child_if_following<'a, C: AggregationContext>( - self: &Arc, - aggregation_context: &C, - children: &mut Vec<&'a I>, - ) { - let mut state = self.state.write(); - let mut removed = StackVec::default(); - children.retain(|&child| match state.following.remove_if_entry(child) { - RemoveIfEntryResult::PartiallyRemoved => false, - RemoveIfEntryResult::NotPresent => true, - RemoveIfEntryResult::Removed => { - removed.push(child); - false - } - }); - if !removed.is_empty() { - propagate_lost_followings_to_uppers(state, aggregation_context, removed); - } - } - - fn remove_child_of_child_following>( - self: &Arc, - aggregation_context: &C, - child_of_child: &I, - ) -> bool { - let mut state = self.state.write(); - if !state.following.remove_clonable(child_of_child) { - // no present, nothing to do - return false; - } - propagate_lost_following_to_uppers(state, aggregation_context, child_of_child); - true - } - - fn remove_children_of_child_following>( - self: &Arc, - aggregation_context: &C, - mut children: StackVec<&I>, - ) { - let mut state = self.state.write(); - children.retain(|&mut child| state.following.remove_clonable(child)); - propagate_lost_followings_to_uppers(state, aggregation_context, children); - } - - fn remove_child_of_child_inner>( - self: &Arc, - aggregation_context: &C, - child_of_child: &I, - ) { - let can_remove_inner = if self.height == 0 { - remove_inner_upper_from_item(aggregation_context, child_of_child, self) - } else { - bottom_tree(aggregation_context, child_of_child, self.height - 1) - .remove_inner_bottom_tree_upper(aggregation_context, self) - }; - if !can_remove_inner { - self.remove_child_of_child_following(aggregation_context, child_of_child); - } - } - - fn remove_children_of_child_inner<'a, C: AggregationContext>( - self: &Arc, - aggregation_context: &C, - children: impl IntoIterator, - ) where - I: 'a, - { - let unremoveable: StackVec<_> = if self.height == 0 { - children - .into_iter() - .filter(|&child| !remove_inner_upper_from_item(aggregation_context, child, self)) - .collect() - } else { - children - .into_iter() - .filter(|&child| { - !bottom_tree(aggregation_context, child, self.height - 1) - .remove_inner_bottom_tree_upper(aggregation_context, self) - }) - .collect() - }; - if !unremoveable.is_empty() { - self.remove_children_of_child_following(aggregation_context, unremoveable); - } - } - - pub fn add_left_bottom_tree_upper>( - &self, - aggregation_context: &C, - upper: &Arc>, - ) { - let mut state = self.state.write(); - let old_inner = state.bottom_upper.set_left_upper(upper); - let add_change = aggregation_context.info_to_add_change(&state.data); - let children = state.following.iter().cloned().collect::>(); - - let remove_change = (!old_inner.is_unset()) - .then(|| aggregation_context.info_to_remove_change(&state.data)) - .flatten(); - - drop(state); - if let Some(change) = add_change { - upper.child_change(aggregation_context, &change); - } - if !children.is_empty() { - upper.add_children_of_child( - aggregation_context, - ChildLocation::Left, - children.iter(), - 1, - ); - } - - // Convert this node into a following node for all old (inner) uppers - // - // Old state: - // I1, I2 - // \ - // self - // Adding L as new left upper: - // I1, I2 L - // \ / - // self - // Final state: (I1 and I2 have L as following instead) - // I1, I2 ----> L - // / - // self - // I1 and I2 have "self" change removed since it's now part of L instead. - // L = upper, I1, I2 = old_inner - // - for (BottomRef { upper: old_upper }, count) in old_inner.into_counts() { - let item = &self.item; - old_upper.migrate_old_inner( - aggregation_context, - item, - count, - &remove_change, - &children, - ); - } - } - - pub fn migrate_old_inner>( - self: &Arc, - aggregation_context: &C, - item: &I, - count: isize, - remove_change: &Option, - following: &[I], - ) { - let mut state = self.state.write(); - if count > 0 { - // add as following - if state.following.add_count(item.clone(), count as usize) { - propagate_new_following_to_uppers(state, aggregation_context, item); - } else { - drop(state); - } - // remove from self - if let Some(change) = remove_change.as_ref() { - self.child_change(aggregation_context, change); - } - self.remove_children_of_child(aggregation_context, following); - } else { - // remove count from following instead - if state.following.remove_count(item.clone(), -count as usize) { - propagate_lost_following_to_uppers(state, aggregation_context, item); - } - } - } - - #[must_use] - pub fn add_inner_bottom_tree_upper>( - &self, - aggregation_context: &C, - upper: &Arc>, - nesting_level: u8, - ) -> bool { - let mut state = self.state.write(); - let number_of_following = state.following.len(); - let BottomConnection::Inner(inner) = &mut state.bottom_upper else { - return false; - }; - if inner.len() * number_of_following > CHILDREN_INNER_THRESHOLD { - return false; - }; - let new = inner.add_clonable(BottomRef::ref_cast(upper), nesting_level); - if new { - if let Some(change) = aggregation_context.info_to_add_change(&state.data) { - upper.child_change(aggregation_context, &change); - } - let children = state.following.iter().cloned().collect::>(); - drop(state); - if !children.is_empty() { - upper.add_children_of_child( - aggregation_context, - ChildLocation::Inner, - &children, - nesting_level + 1, - ); - } - } - true - } - - pub fn remove_left_bottom_tree_upper>( - self: &Arc, - aggregation_context: &C, - upper: &Arc>, - ) { - let mut state = self.state.write(); - state.bottom_upper.unset_left_upper(upper); - if let Some(change) = aggregation_context.info_to_remove_change(&state.data) { - upper.child_change(aggregation_context, &change); - } - let following = state.following.iter().cloned().collect::>(); - if state.top_upper.is_empty() { - drop(state); - self.remove_self_from_lower(aggregation_context); - } else { - drop(state); - } - upper.remove_children_of_child(aggregation_context, &following); - } - - #[must_use] - pub fn remove_inner_bottom_tree_upper>( - &self, - aggregation_context: &C, - upper: &Arc>, - ) -> bool { - let mut state = self.state.write(); - let BottomConnection::Inner(inner) = &mut state.bottom_upper else { - return false; - }; - let removed = inner.remove_clonable(BottomRef::ref_cast(upper)); - if removed { - let remove_change = aggregation_context.info_to_remove_change(&state.data); - let following = state.following.iter().cloned().collect::>(); - drop(state); - if let Some(change) = remove_change { - upper.child_change(aggregation_context, &change); - } - upper.remove_children_of_child(aggregation_context, &following); - } - true - } - - pub fn add_top_tree_upper>( - &self, - aggregation_context: &C, - upper: &Arc>, - ) { - let mut state = self.state.write(); - let new = state.top_upper.add_clonable(TopRef::ref_cast(upper)); - if new { - if let Some(change) = aggregation_context.info_to_add_change(&state.data) { - upper.child_change(aggregation_context, &change); - } - for following in state.following.iter() { - upper.add_child_of_child(aggregation_context, following); - } - } - } - - #[allow(dead_code)] - pub fn remove_top_tree_upper>( - self: &Arc, - aggregation_context: &C, - upper: &Arc>, - ) { - let mut state = self.state.write(); - let removed = state.top_upper.remove_clonable(TopRef::ref_cast(upper)); - if removed { - if let Some(change) = aggregation_context.info_to_remove_change(&state.data) { - upper.child_change(aggregation_context, &change); - } - for following in state.following.iter() { - upper.remove_child_of_child(aggregation_context, following); - } - if state.top_upper.is_empty() - && !matches!(state.bottom_upper, BottomConnection::Left(_)) - { - drop(state); - self.remove_self_from_lower(aggregation_context); - } - } - } - - fn remove_self_from_lower( - self: &Arc, - aggregation_context: &impl AggregationContext, - ) { - if self.height == 0 { - remove_left_upper_from_item(aggregation_context, &self.item, self); - } else { - bottom_tree(aggregation_context, &self.item, self.height - 1) - .remove_left_bottom_tree_upper(aggregation_context, self); - } - } - - pub fn child_change>( - &self, - aggregation_context: &C, - change: &C::ItemChange, - ) { - let mut state = self.state.write(); - let change = aggregation_context.apply_change(&mut state.data, change); - let state = RwLockWriteGuard::downgrade(state); - propagate_change_to_upper(&state, aggregation_context, change); - } - - pub fn get_root_info>( - &self, - aggregation_context: &C, - root_info_type: &C::RootInfoType, - ) -> C::RootInfo { - let mut result = aggregation_context.new_root_info(root_info_type); - let top_uppers = { - let state = self.state.read(); - state.top_upper.iter().cloned().collect::>() - }; - for TopRef { upper } in top_uppers.iter() { - let info = upper.get_root_info(aggregation_context, root_info_type); - if aggregation_context.merge_root_info(&mut result, info) == ControlFlow::Break(()) { - return result; - } - } - let bottom_uppers = { - let state = self.state.read(); - state.bottom_upper.as_cloned_uppers() - }; - bottom_uppers.get_root_info(aggregation_context, root_info_type, result) - } -} - -fn propagate_lost_following_to_uppers( - state: RwLockWriteGuard<'_, BottomTreeState>, - aggregation_context: &C, - child_of_child: &C::ItemRef, -) { - let bottom_uppers = state.bottom_upper.as_cloned_uppers(); - let top_upper = state.top_upper.iter().cloned().collect::>(); - drop(state); - for TopRef { upper } in top_upper { - upper.remove_child_of_child(aggregation_context, child_of_child); - } - bottom_uppers.remove_child_of_child(aggregation_context, child_of_child); -} - -fn propagate_lost_followings_to_uppers<'a, C: AggregationContext>( - state: RwLockWriteGuard<'_, BottomTreeState>, - aggregation_context: &C, - children: impl IntoIterator + Clone, -) where - C::ItemRef: 'a, -{ - let bottom_uppers = state.bottom_upper.as_cloned_uppers(); - let top_upper = state.top_upper.iter().cloned().collect::>(); - drop(state); - for TopRef { upper } in top_upper { - upper.remove_children_of_child(aggregation_context, children.clone()); - } - bottom_uppers.remove_children_of_child(aggregation_context, children); -} - -fn propagate_new_following_to_uppers( - state: RwLockWriteGuard<'_, BottomTreeState>, - aggregation_context: &C, - child_of_child: &C::ItemRef, -) { - let bottom_uppers = state.bottom_upper.as_cloned_uppers(); - let top_upper = state.top_upper.iter().cloned().collect::>(); - drop(state); - for TopRef { upper } in top_upper { - upper.add_child_of_child(aggregation_context, child_of_child); - } - bottom_uppers.add_child_of_child(aggregation_context, child_of_child); -} - -fn propagate_change_to_upper( - state: &RwLockReadGuard>, - aggregation_context: &C, - change: Option, -) { - let Some(change) = change else { - return; - }; - state - .bottom_upper - .child_change(aggregation_context, &change); - for TopRef { upper } in state.top_upper.iter() { - upper.child_change(aggregation_context, &change); - } -} - -#[allow(clippy::disallowed_methods)] // Allow VecDeque::new() in this test -#[cfg(test)] -fn visit_graph( - aggregation_context: &C, - entry: &C::ItemRef, - height: u8, -) -> (usize, usize) { - use std::collections::{HashSet, VecDeque}; - let mut queue = VecDeque::new(); - let mut visited = HashSet::new(); - visited.insert(entry.clone()); - queue.push_back(entry.clone()); - let mut edges = 0; - while let Some(item) = queue.pop_front() { - let tree = bottom_tree(aggregation_context, &item, height); - let state = tree.state.read(); - for next in state.following.iter() { - edges += 1; - if visited.insert(next.clone()) { - queue.push_back(next.clone()); - } - } - } - (visited.len(), edges) -} - -#[allow(clippy::disallowed_methods)] // Allow VecDeque::new() in this test -#[cfg(test)] -pub fn print_graph( - aggregation_context: &C, - entry: &C::ItemRef, - height: u8, - color_upper: bool, - name_fn: impl Fn(&C::ItemRef) -> String, -) { - use std::{ - collections::{HashSet, VecDeque}, - fmt::Write, - }; - let (nodes, edges) = visit_graph(aggregation_context, entry, height); - if !color_upper { - print!("subgraph cluster_{} {{", height); - print!( - "label = \"Level {}\\n{} nodes, {} edges\";", - height, nodes, edges - ); - print!("color = \"black\";"); - } - let mut edges = String::new(); - let mut queue = VecDeque::new(); - let mut visited = HashSet::new(); - visited.insert(entry.clone()); - queue.push_back(entry.clone()); - while let Some(item) = queue.pop_front() { - let tree = bottom_tree(aggregation_context, &item, height); - let name = name_fn(&item); - let label = name.to_string(); - let state = tree.state.read(); - if color_upper { - print!(r#""{} {}" [color=red];"#, height - 1, name); - } else { - print!(r#""{} {}" [label="{}"];"#, height, name, label); - } - for next in state.following.iter() { - if !color_upper { - write!( - edges, - r#""{} {}" -> "{} {}";"#, - height, - name, - height, - name_fn(next) - ) - .unwrap(); - } - if visited.insert(next.clone()) { - queue.push_back(next.clone()); - } - } - } - if !color_upper { - println!("}}"); - println!("{}", edges); - } -} diff --git a/crates/turbo-tasks-memory/src/aggregation_tree/inner_refs.rs b/crates/turbo-tasks-memory/src/aggregation_tree/inner_refs.rs deleted file mode 100644 index 92693d65f3a7c..0000000000000 --- a/crates/turbo-tasks-memory/src/aggregation_tree/inner_refs.rs +++ /dev/null @@ -1,79 +0,0 @@ -use std::{ - hash::{Hash, Hasher}, - sync::Arc, -}; - -use nohash_hasher::IsEnabled; -use ref_cast::RefCast; - -use super::{bottom_tree::BottomTree, top_tree::TopTree}; - -#[derive(Clone, Copy, PartialEq, Eq, Hash)] -pub enum ChildLocation { - // Left-most child - Left, - // Inner child, not left-most - Inner, -} - -/// A reference to a [TopTree]. -#[derive(RefCast)] -#[repr(transparent)] -pub struct TopRef { - pub upper: Arc>, -} - -impl IsEnabled for TopRef {} - -impl Hash for TopRef { - fn hash(&self, state: &mut H) { - Arc::as_ptr(&self.upper).hash(state); - } -} - -impl PartialEq for TopRef { - fn eq(&self, other: &Self) -> bool { - Arc::ptr_eq(&self.upper, &other.upper) - } -} - -impl Eq for TopRef {} - -impl Clone for TopRef { - fn clone(&self) -> Self { - Self { - upper: self.upper.clone(), - } - } -} - -/// A reference to a [BottomTree]. -#[derive(RefCast)] -#[repr(transparent)] -pub struct BottomRef { - pub upper: Arc>, -} - -impl Hash for BottomRef { - fn hash(&self, state: &mut H) { - Arc::as_ptr(&self.upper).hash(state); - } -} - -impl IsEnabled for BottomRef {} - -impl PartialEq for BottomRef { - fn eq(&self, other: &Self) -> bool { - Arc::ptr_eq(&self.upper, &other.upper) - } -} - -impl Eq for BottomRef {} - -impl Clone for BottomRef { - fn clone(&self) -> Self { - Self { - upper: self.upper.clone(), - } - } -} diff --git a/crates/turbo-tasks-memory/src/aggregation_tree/leaf.rs b/crates/turbo-tasks-memory/src/aggregation_tree/leaf.rs deleted file mode 100644 index 0166e66d6a6e0..0000000000000 --- a/crates/turbo-tasks-memory/src/aggregation_tree/leaf.rs +++ /dev/null @@ -1,403 +0,0 @@ -use std::{hash::Hash, sync::Arc}; - -use auto_hash_map::AutoSet; -use nohash_hasher::IsEnabled; -use ref_cast::RefCast; -use tracing::Level; - -use super::{ - bottom_connection::{BottomConnection, DistanceCountMap}, - bottom_tree::BottomTree, - inner_refs::{BottomRef, ChildLocation}, - top_tree::TopTree, - AggregationContext, AggregationItemLock, LargeStackVec, CHILDREN_INNER_THRESHOLD, -}; - -/// The leaf of the aggregation tree. It's usually stored inside of the nodes -/// that should be aggregated by the aggregation tree. It caches [TopTree]s and -/// [BottomTree]s created from that node. And it also stores the upper bottom -/// trees. -pub struct AggregationTreeLeaf { - top_trees: Vec>>>, - bottom_trees: Vec>>>, - upper: BottomConnection, -} - -impl AggregationTreeLeaf { - pub fn new() -> Self { - Self { - top_trees: Vec::new(), - bottom_trees: Vec::new(), - upper: BottomConnection::new(), - } - } - - /// Prepares the addition of new children. It returns a closure that should - /// be executed outside of the leaf lock. - #[allow(unused)] - pub fn add_children_job<'a, C: AggregationContext>( - &self, - aggregation_context: &'a C, - children: Vec, - ) -> impl FnOnce() + 'a - where - I: 'a, - T: 'a, - { - let uppers = self.upper.as_cloned_uppers(); - move || { - uppers.add_children_of_child(aggregation_context, &children); - } - } - - /// Prepares the addition of a new child. It returns a closure that should - /// be executed outside of the leaf lock. - pub fn add_child_job<'a, C: AggregationContext>( - &self, - aggregation_context: &'a C, - child: &'a I, - ) -> impl FnOnce() + 'a - where - T: 'a, - { - let uppers = self.upper.as_cloned_uppers(); - move || { - uppers.add_child_of_child(aggregation_context, child); - } - } - - /// Removes a child. - pub fn remove_child>( - &self, - aggregation_context: &C, - child: &I, - ) { - self.upper - .as_cloned_uppers() - .remove_child_of_child(aggregation_context, child); - } - - /// Prepares the removal of a child. It returns a closure that should be - /// executed outside of the leaf lock. - pub fn remove_children_job< - 'a, - C: AggregationContext, - H, - const N: usize, - >( - &self, - aggregation_context: &'a C, - children: AutoSet, - ) -> impl FnOnce() + 'a - where - T: 'a, - I: 'a, - H: 'a, - { - let uppers = self.upper.as_cloned_uppers(); - move || uppers.remove_children_of_child(aggregation_context, children.iter()) - } - - /// Communicates a change on the leaf to updated aggregated nodes. Prefer - /// [Self::change_job] to avoid leaf locking. - pub fn change>( - &self, - aggregation_context: &C, - change: &C::ItemChange, - ) { - self.upper.child_change(aggregation_context, change); - } - - /// Prepares the communication of a change on the leaf to updated aggregated - /// nodes. It returns a closure that should be executed outside of the leaf - /// lock. - pub fn change_job<'a, C: AggregationContext>( - &self, - aggregation_context: &'a C, - change: C::ItemChange, - ) -> impl FnOnce() + 'a - where - I: 'a, - T: 'a, - { - let uppers = self.upper.as_cloned_uppers(); - move || { - uppers.child_change(aggregation_context, &change); - } - } - - /// Captures information about the aggregation tree roots. - pub fn get_root_info>( - &self, - aggregation_context: &C, - root_info_type: &C::RootInfoType, - ) -> C::RootInfo { - self.upper.get_root_info( - aggregation_context, - root_info_type, - aggregation_context.new_root_info(root_info_type), - ) - } - - pub fn has_upper(&self) -> bool { - !self.upper.is_unset() - } -} - -fn get_or_create_in_vec( - vec: &mut Vec>, - index: usize, - create: impl FnOnce() -> T, -) -> (&mut T, bool) { - if vec.len() <= index { - vec.resize_with(index + 1, || None); - } - let item = &mut vec[index]; - if item.is_none() { - *item = Some(create()); - (item.as_mut().unwrap(), true) - } else { - (item.as_mut().unwrap(), false) - } -} - -#[tracing::instrument(level = Level::TRACE, skip(aggregation_context, reference))] -pub fn top_tree( - aggregation_context: &C, - reference: &C::ItemRef, - depth: u8, -) -> Arc> { - let new_top_tree = { - let mut item = aggregation_context.item(reference); - let leaf = item.leaf(); - let (tree, new) = get_or_create_in_vec(&mut leaf.top_trees, depth as usize, || { - Arc::new(TopTree::new(depth)) - }); - if !new { - return tree.clone(); - } - tree.clone() - }; - let bottom_tree = bottom_tree(aggregation_context, reference, depth + 4); - bottom_tree.add_top_tree_upper(aggregation_context, &new_top_tree); - new_top_tree -} - -pub fn bottom_tree( - aggregation_context: &C, - reference: &C::ItemRef, - height: u8, -) -> Arc> { - let _span; - let new_bottom_tree; - let mut result = None; - { - let mut item = aggregation_context.item(reference); - let leaf = item.leaf(); - let (tree, new) = get_or_create_in_vec(&mut leaf.bottom_trees, height as usize, || { - Arc::new(BottomTree::new(reference.clone(), height)) - }); - if !new { - return tree.clone(); - } - new_bottom_tree = tree.clone(); - _span = (height > 2).then(|| tracing::trace_span!("bottom_tree", height).entered()); - - if height == 0 { - result = Some(add_left_upper_to_item_step_1::( - &mut item, - &new_bottom_tree, - )); - } - } - if let Some(result) = result { - add_left_upper_to_item_step_2(aggregation_context, reference, &new_bottom_tree, result); - } - if height != 0 { - bottom_tree(aggregation_context, reference, height - 1) - .add_left_bottom_tree_upper(aggregation_context, &new_bottom_tree); - } - new_bottom_tree -} - -#[must_use] -pub fn add_inner_upper_to_item( - aggregation_context: &C, - reference: &C::ItemRef, - upper: &Arc>, - nesting_level: u8, -) -> bool { - let (change, children) = { - let mut item = aggregation_context.item(reference); - let number_of_children = item.number_of_children(); - let leaf = item.leaf(); - let BottomConnection::Inner(inner) = &mut leaf.upper else { - return false; - }; - if inner.len() * number_of_children > CHILDREN_INNER_THRESHOLD { - return false; - } - let new = inner.add_clonable(BottomRef::ref_cast(upper), nesting_level); - if new { - let change = item.get_add_change(); - ( - change, - item.children() - .map(|r| r.into_owned()) - .collect::>(), - ) - } else { - return true; - } - }; - if let Some(change) = change { - upper.child_change(aggregation_context, &change); - } - if !children.is_empty() { - upper.add_children_of_child( - aggregation_context, - ChildLocation::Inner, - &children, - nesting_level + 1, - ) - } - true -} - -struct AddLeftUpperIntermediateResult( - Option, - LargeStackVec, - DistanceCountMap>, - Option, -); - -#[must_use] -fn add_left_upper_to_item_step_1( - item: &mut C::ItemLock<'_>, - upper: &Arc>, -) -> AddLeftUpperIntermediateResult { - let old_inner = item.leaf().upper.set_left_upper(upper); - let remove_change_for_old_inner = (!old_inner.is_unset()) - .then(|| item.get_remove_change()) - .flatten(); - let children = item.children().map(|r| r.into_owned()).collect(); - AddLeftUpperIntermediateResult( - item.get_add_change(), - children, - old_inner, - remove_change_for_old_inner, - ) -} - -fn add_left_upper_to_item_step_2( - aggregation_context: &C, - reference: &C::ItemRef, - upper: &Arc>, - step_1_result: AddLeftUpperIntermediateResult, -) { - let AddLeftUpperIntermediateResult(change, children, old_inner, remove_change_for_old_inner) = - step_1_result; - if let Some(change) = change { - upper.child_change(aggregation_context, &change); - } - if !children.is_empty() { - upper.add_children_of_child(aggregation_context, ChildLocation::Left, &children, 1) - } - for (BottomRef { upper: old_upper }, count) in old_inner.into_counts() { - old_upper.migrate_old_inner( - aggregation_context, - reference, - count, - &remove_change_for_old_inner, - &children, - ); - } -} - -pub fn remove_left_upper_from_item( - aggregation_context: &C, - reference: &C::ItemRef, - upper: &Arc>, -) { - let mut item = aggregation_context.item(reference); - let leaf = &mut item.leaf(); - leaf.upper.unset_left_upper(upper); - let change = item.get_remove_change(); - let children = item.children().map(|r| r.into_owned()).collect::>(); - drop(item); - if let Some(change) = change { - upper.child_change(aggregation_context, &change); - } - for child in children { - upper.remove_child_of_child(aggregation_context, &child) - } -} - -#[must_use] -pub fn remove_inner_upper_from_item( - aggregation_context: &C, - reference: &C::ItemRef, - upper: &Arc>, -) -> bool { - let mut item = aggregation_context.item(reference); - let BottomConnection::Inner(inner) = &mut item.leaf().upper else { - return false; - }; - if !inner.remove_clonable(BottomRef::ref_cast(upper)) { - // Nothing to do - return true; - } - let change = item.get_remove_change(); - let children = item - .children() - .map(|r| r.into_owned()) - .collect::>(); - drop(item); - - if let Some(change) = change { - upper.child_change(aggregation_context, &change); - } - for child in children { - upper.remove_child_of_child(aggregation_context, &child) - } - true -} - -/// Checks thresholds for an item to ensure the aggregation graph stays -/// well-formed. Run this before adding a child to an item. Returns a closure -/// that should be executed outside of the leaf lock. -pub fn ensure_thresholds<'a, C: AggregationContext>( - aggregation_context: &'a C, - item: &mut C::ItemLock<'_>, -) -> Option { - let mut result = None; - - let number_of_total_children = item.number_of_children(); - let reference = item.reference().clone(); - let leaf = item.leaf(); - if let BottomConnection::Inner(list) = &leaf.upper { - if list.len() * number_of_total_children > CHILDREN_INNER_THRESHOLD { - let (tree, new) = get_or_create_in_vec(&mut leaf.bottom_trees, 0, || { - Arc::new(BottomTree::new(reference.clone(), 0)) - }); - debug_assert!(new); - let new_bottom_tree = tree.clone(); - result = Some(( - add_left_upper_to_item_step_1::(item, &new_bottom_tree), - reference, - new_bottom_tree, - )); - } - } - result.map(|(result, reference, new_bottom_tree)| { - move || { - let _span = tracing::trace_span!("aggregation_tree::reorganize").entered(); - add_left_upper_to_item_step_2( - aggregation_context, - &reference, - &new_bottom_tree, - result, - ); - } - }) -} diff --git a/crates/turbo-tasks-memory/src/aggregation_tree/mod.rs b/crates/turbo-tasks-memory/src/aggregation_tree/mod.rs deleted file mode 100644 index fd6e81a14e6a5..0000000000000 --- a/crates/turbo-tasks-memory/src/aggregation_tree/mod.rs +++ /dev/null @@ -1,150 +0,0 @@ -//! The module implements a datastructure that aggregates a "forest" into less -//! nodes. For any node one can ask for a single aggregated version of all -//! children on that node. Changes the forest will propagate up the -//! aggregation tree to keep it up to date. So asking of an aggregated -//! information is cheap and one can even wait for aggregated info to change. -//! -//! The aggregation will try to reuse aggregated nodes on every level to reduce -//! memory and cpu usage of propagating changes. The tree structure is designed -//! for multi-thread usage. -//! -//! The aggregation tree is build out of two halfs. The top tree and the bottom -//! tree. One node of the bottom tree can aggregate items of connectivity -//! 2^height. It will do that by having bottom trees of height - 1 as children. -//! One node of the top tree can aggregate items of any connectivity. It will do -//! that by having a bottom tree of height = depth as a child and top trees of -//! depth + 1 as children. So it's basically a linked list of bottom trees of -//! increasing height. Any top or bottom node can be shared between multiple -//! parents. -//! -//! Notations: -//! - parent/child: Relationship in the original forest resp. the aggregated -//! version of the relationships. -//! - upper: Relationship to a aggregated node in a higher level (more -//! aggregated). Since all communication is strictly upwards there is no down -//! relationship for that. - -mod bottom_connection; -mod bottom_tree; -mod inner_refs; -mod leaf; -#[cfg(test)] -mod tests; -mod top_tree; - -use std::{borrow::Cow, hash::Hash, ops::ControlFlow, sync::Arc}; - -use nohash_hasher::IsEnabled; -use smallvec::SmallVec; - -use self::{leaf::top_tree, top_tree::TopTree}; -pub use self::{ - leaf::{ensure_thresholds, AggregationTreeLeaf}, - top_tree::AggregationInfoGuard, -}; - -/// The maximum connectivity of one layer of bottom tree. -const CONNECTIVITY_LIMIT: u8 = 7; - -/// The maximum of number of children muliplied by number of upper bottom trees. -/// When reached the parent of the children will form a new bottom tree. -const CHILDREN_INNER_THRESHOLD: usize = 2000; - -type StackVec = SmallVec<[I; 16]>; -type LargeStackVec = SmallVec<[I; 32]>; - -/// The context trait which defines how the aggregation tree should behave. -pub trait AggregationContext { - type ItemLock<'a>: AggregationItemLock< - ItemRef = Self::ItemRef, - Info = Self::Info, - ItemChange = Self::ItemChange, - > - where - Self: 'a; - type Info: Default; - type ItemChange; - type ItemRef: Eq + Hash + Clone + IsEnabled; - type RootInfo; - type RootInfoType; - - /// Gets mutable access to an item. - fn item<'a>(&'a self, reference: &Self::ItemRef) -> Self::ItemLock<'a>; - - /// Apply a changeset to an aggregated info object. Returns a new changeset - /// that should be applied to the next aggregation level. Might return None, - /// if no change should be applied to the next level. - fn apply_change( - &self, - info: &mut Self::Info, - change: &Self::ItemChange, - ) -> Option; - - /// Creates a changeset from an aggregated info object, that represents - /// adding the aggregated node to an aggregated node of the next level. - fn info_to_add_change(&self, info: &Self::Info) -> Option; - /// Creates a changeset from an aggregated info object, that represents - /// removing the aggregated node from an aggregated node of the next level. - fn info_to_remove_change(&self, info: &Self::Info) -> Option; - - /// Initializes a new empty root info object. - fn new_root_info(&self, root_info_type: &Self::RootInfoType) -> Self::RootInfo; - /// Creates a new root info object from an aggregated info object. This is - /// only called on the root of the aggregation tree. - fn info_to_root_info( - &self, - info: &Self::Info, - root_info_type: &Self::RootInfoType, - ) -> Self::RootInfo; - /// Merges two root info objects. Can optionally break the root info - /// gathering which will return this root info object as final result. - fn merge_root_info( - &self, - root_info: &mut Self::RootInfo, - other: Self::RootInfo, - ) -> ControlFlow<()>; -} - -/// A lock on a single item. -pub trait AggregationItemLock { - type Info; - type ItemRef: Clone + IsEnabled; - type ItemChange; - type ChildrenIter<'a>: Iterator> + 'a - where - Self: 'a; - /// Returns a reference to the item. - fn reference(&self) -> &Self::ItemRef; - /// Get mutable access to the leaf info. - fn leaf(&mut self) -> &mut AggregationTreeLeaf; - /// Returns the number of children. - fn number_of_children(&self) -> usize; - /// Returns an iterator over the children. - fn children(&self) -> Self::ChildrenIter<'_>; - /// Returns a changeset that represents the addition of the item. - fn get_add_change(&self) -> Option; - /// Returns a changeset that represents the removal of the item. - fn get_remove_change(&self) -> Option; -} - -/// Gives an reference to the root aggregated info for a given item. -pub fn aggregation_info( - aggregation_context: &C, - reference: &C::ItemRef, -) -> AggregationInfoReference { - AggregationInfoReference { - tree: top_tree(aggregation_context, reference, 0), - } -} - -/// A reference to the root aggregated info of a node. -pub struct AggregationInfoReference { - tree: Arc>, -} - -impl AggregationInfoReference { - /// Locks the info and gives mutable access to it. - pub fn lock(&self) -> AggregationInfoGuard { - self.tree.lock_info() - } -} diff --git a/crates/turbo-tasks-memory/src/aggregation_tree/tests.rs b/crates/turbo-tasks-memory/src/aggregation_tree/tests.rs deleted file mode 100644 index 53c73d1a6513c..0000000000000 --- a/crates/turbo-tasks-memory/src/aggregation_tree/tests.rs +++ /dev/null @@ -1,600 +0,0 @@ -use std::{ - borrow::Cow, - hash::Hash, - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, - }, - time::Instant, -}; - -use nohash_hasher::IsEnabled; -use parking_lot::{Mutex, MutexGuard}; -use ref_cast::RefCast; - -use super::{aggregation_info, AggregationContext, AggregationItemLock, AggregationTreeLeaf}; -use crate::aggregation_tree::{bottom_tree::print_graph, leaf::ensure_thresholds}; - -struct Node { - inner: Mutex, -} - -impl Node { - fn incr(&self, aggregation_context: &NodeAggregationContext) { - let mut guard = self.inner.lock(); - guard.value += 10000; - guard - .aggregation_leaf - .change(aggregation_context, &Change { value: 10000 }); - } -} - -#[derive(Copy, Clone)] -struct Change { - value: i32, -} - -impl Change { - fn is_empty(&self) -> bool { - self.value == 0 - } -} - -struct NodeInner { - children: Vec>, - aggregation_leaf: AggregationTreeLeaf, - value: u32, -} - -struct NodeAggregationContext<'a> { - additions: AtomicU32, - #[allow(dead_code)] - something_with_lifetime: &'a u32, - add_value: bool, -} - -#[derive(Clone, RefCast)] -#[repr(transparent)] -struct NodeRef(Arc); - -impl Hash for NodeRef { - fn hash(&self, state: &mut H) { - Arc::as_ptr(&self.0).hash(state); - } -} - -impl IsEnabled for NodeRef {} - -impl PartialEq for NodeRef { - fn eq(&self, other: &Self) -> bool { - Arc::ptr_eq(&self.0, &other.0) - } -} - -impl Eq for NodeRef {} - -struct NodeGuard { - guard: MutexGuard<'static, NodeInner>, - node: Arc, -} - -impl NodeGuard { - unsafe fn new(guard: MutexGuard<'_, NodeInner>, node: Arc) -> Self { - NodeGuard { - guard: unsafe { std::mem::transmute(guard) }, - node, - } - } -} - -impl AggregationItemLock for NodeGuard { - type Info = Aggregated; - type ItemRef = NodeRef; - type ItemChange = Change; - type ChildrenIter<'a> = impl Iterator> + 'a; - - fn reference(&self) -> &Self::ItemRef { - NodeRef::ref_cast(&self.node) - } - - fn leaf(&mut self) -> &mut AggregationTreeLeaf { - &mut self.guard.aggregation_leaf - } - - fn number_of_children(&self) -> usize { - self.guard.children.len() - } - - fn children(&self) -> Self::ChildrenIter<'_> { - self.guard - .children - .iter() - .map(|child| Cow::Owned(NodeRef(child.clone()))) - } - - fn get_remove_change(&self) -> Option { - let change = Change { - value: -(self.guard.value as i32), - }; - if change.is_empty() { - None - } else { - Some(change) - } - } - - fn get_add_change(&self) -> Option { - let change = Change { - value: self.guard.value as i32, - }; - if change.is_empty() { - None - } else { - Some(change) - } - } -} - -impl<'a> AggregationContext for NodeAggregationContext<'a> { - type ItemLock<'l> = NodeGuard where Self: 'l; - type Info = Aggregated; - type ItemRef = NodeRef; - type ItemChange = Change; - - fn item<'b>(&'b self, reference: &Self::ItemRef) -> Self::ItemLock<'b> { - let r = reference.0.clone(); - let guard = reference.0.inner.lock(); - unsafe { NodeGuard::new(guard, r) } - } - - fn apply_change(&self, info: &mut Aggregated, change: &Change) -> Option { - if info.value != 0 { - self.additions.fetch_add(1, Ordering::SeqCst); - } - if self.add_value { - info.value += change.value; - } - Some(*change) - } - - fn info_to_add_change(&self, info: &Self::Info) -> Option { - let change = Change { value: info.value }; - if change.is_empty() { - None - } else { - Some(change) - } - } - - fn info_to_remove_change(&self, info: &Self::Info) -> Option { - let change = Change { value: -info.value }; - if change.is_empty() { - None - } else { - Some(change) - } - } - - type RootInfo = bool; - - type RootInfoType = (); - - fn new_root_info(&self, root_info_type: &Self::RootInfoType) -> Self::RootInfo { - #[allow(clippy::match_single_binding)] - match root_info_type { - () => false, - } - } - - fn info_to_root_info( - &self, - info: &Self::Info, - root_info_type: &Self::RootInfoType, - ) -> Self::RootInfo { - #[allow(clippy::match_single_binding)] - match root_info_type { - () => info.active, - } - } - - fn merge_root_info( - &self, - root_info: &mut Self::RootInfo, - other: Self::RootInfo, - ) -> std::ops::ControlFlow<()> { - if other { - *root_info = true; - std::ops::ControlFlow::Break(()) - } else { - std::ops::ControlFlow::Continue(()) - } - } -} - -#[derive(Default)] -struct Aggregated { - value: i32, - active: bool, -} - -#[test] -fn chain() { - let something_with_lifetime = 0; - let ctx = NodeAggregationContext { - additions: AtomicU32::new(0), - something_with_lifetime: &something_with_lifetime, - add_value: true, - }; - let leaf = Arc::new(Node { - inner: Mutex::new(NodeInner { - children: vec![], - aggregation_leaf: AggregationTreeLeaf::new(), - value: 10000, - }), - }); - let mut current = leaf.clone(); - for i in 1..=100 { - current = Arc::new(Node { - inner: Mutex::new(NodeInner { - children: vec![current], - aggregation_leaf: AggregationTreeLeaf::new(), - value: i, - }), - }); - } - let current = NodeRef(current); - - { - let root_info = leaf.inner.lock().aggregation_leaf.get_root_info(&ctx, &()); - assert!(!root_info); - } - - { - let aggregated = aggregation_info(&ctx, ¤t); - assert_eq!(aggregated.lock().value, 15050); - } - assert_eq!(ctx.additions.load(Ordering::SeqCst), 100); - ctx.additions.store(0, Ordering::SeqCst); - - print(&ctx, ¤t); - - { - let root_info = leaf.inner.lock().aggregation_leaf.get_root_info(&ctx, &()); - assert!(!root_info); - } - - leaf.incr(&ctx); - // The change need to propagate through 5 top trees and 5 bottom trees - assert_eq!(ctx.additions.load(Ordering::SeqCst), 6); - ctx.additions.store(0, Ordering::SeqCst); - - { - let aggregated = aggregation_info(&ctx, ¤t); - let mut aggregated = aggregated.lock(); - assert_eq!(aggregated.value, 25050); - aggregated.active = true; - } - assert_eq!(ctx.additions.load(Ordering::SeqCst), 0); - ctx.additions.store(0, Ordering::SeqCst); - - { - let root_info = leaf.inner.lock().aggregation_leaf.get_root_info(&ctx, &()); - assert!(root_info); - } - - let i = 101; - let current = Arc::new(Node { - inner: Mutex::new(NodeInner { - children: vec![current.0], - aggregation_leaf: AggregationTreeLeaf::new(), - value: i, - }), - }); - let current = NodeRef(current); - - { - let aggregated = aggregation_info(&ctx, ¤t); - let aggregated = aggregated.lock(); - assert_eq!(aggregated.value, 25151); - } - // This should be way less the 100 to prove that we are reusing trees - assert_eq!(ctx.additions.load(Ordering::SeqCst), 1); - ctx.additions.store(0, Ordering::SeqCst); - - leaf.incr(&ctx); - // This should be less the 20 to prove that we are reusing trees - assert_eq!(ctx.additions.load(Ordering::SeqCst), 9); - ctx.additions.store(0, Ordering::SeqCst); - - { - let root_info = leaf.inner.lock().aggregation_leaf.get_root_info(&ctx, &()); - assert!(root_info); - } -} - -#[test] -fn chain_double_connected() { - let something_with_lifetime = 0; - let ctx = NodeAggregationContext { - additions: AtomicU32::new(0), - something_with_lifetime: &something_with_lifetime, - add_value: true, - }; - let leaf = Arc::new(Node { - inner: Mutex::new(NodeInner { - children: vec![], - aggregation_leaf: AggregationTreeLeaf::new(), - value: 1, - }), - }); - let mut current = leaf.clone(); - let mut current2 = Arc::new(Node { - inner: Mutex::new(NodeInner { - children: vec![leaf.clone()], - aggregation_leaf: AggregationTreeLeaf::new(), - value: 2, - }), - }); - for i in 3..=100 { - let new_node = Arc::new(Node { - inner: Mutex::new(NodeInner { - children: vec![current, current2.clone()], - aggregation_leaf: AggregationTreeLeaf::new(), - value: i, - }), - }); - current = current2; - current2 = new_node; - } - let current = NodeRef(current2); - - print(&ctx, ¤t); - - { - let aggregated = aggregation_info(&ctx, ¤t); - assert_eq!(aggregated.lock().value, 8230); - } - assert_eq!(ctx.additions.load(Ordering::SeqCst), 204); - ctx.additions.store(0, Ordering::SeqCst); -} - -const RECT_SIZE: usize = 100; -const RECT_MULT: usize = 100; - -#[test] -fn rectangle_tree() { - let something_with_lifetime = 0; - let ctx = NodeAggregationContext { - additions: AtomicU32::new(0), - something_with_lifetime: &something_with_lifetime, - add_value: false, - }; - let mut nodes: Vec>> = Vec::new(); - for y in 0..RECT_SIZE { - let mut line: Vec> = Vec::new(); - for x in 0..RECT_SIZE { - let mut children = Vec::new(); - if x > 0 { - children.push(line[x - 1].clone()); - } - if y > 0 { - children.push(nodes[y - 1][x].clone()); - } - let value = (x + y * RECT_MULT) as u32; - let node = Arc::new(Node { - inner: Mutex::new(NodeInner { - children, - aggregation_leaf: AggregationTreeLeaf::new(), - value, - }), - }); - line.push(node.clone()); - } - nodes.push(line); - } - - let root = NodeRef(nodes[RECT_SIZE - 1][RECT_SIZE - 1].clone()); - - print(&ctx, &root); -} - -#[test] -fn rectangle_adding_tree() { - let something_with_lifetime = 0; - let ctx = NodeAggregationContext { - additions: AtomicU32::new(0), - something_with_lifetime: &something_with_lifetime, - add_value: false, - }; - let mut nodes: Vec>> = Vec::new(); - - fn add_child( - parent: &Arc, - node: &Arc, - aggregation_context: &NodeAggregationContext<'_>, - ) { - let node_ref = NodeRef(node.clone()); - let mut state = parent.inner.lock(); - state.children.push(node.clone()); - let job = state - .aggregation_leaf - .add_child_job(aggregation_context, &node_ref); - drop(state); - job(); - } - for y in 0..RECT_SIZE { - let mut line: Vec> = Vec::new(); - for x in 0..RECT_SIZE { - let value = (x + y * RECT_MULT) as u32; - let node = Arc::new(Node { - inner: Mutex::new(NodeInner { - children: Vec::new(), - aggregation_leaf: AggregationTreeLeaf::new(), - value, - }), - }); - line.push(node.clone()); - if x > 0 { - let parent = &line[x - 1]; - add_child(parent, &node, &ctx); - } - if y > 0 { - let parent = &nodes[y - 1][x]; - add_child(parent, &node, &ctx); - } - if x == 0 && y == 0 { - aggregation_info(&ctx, &NodeRef(node.clone())).lock().active = true; - } - } - nodes.push(line); - } - - let root = NodeRef(nodes[0][0].clone()); - - print(&ctx, &root); -} - -#[test] -fn many_children() { - let something_with_lifetime = 0; - let ctx = NodeAggregationContext { - additions: AtomicU32::new(0), - something_with_lifetime: &something_with_lifetime, - add_value: false, - }; - let mut roots: Vec> = Vec::new(); - let mut children: Vec> = Vec::new(); - const CHILDREN: u32 = 5000; - const ROOTS: u32 = 100; - let inner_node = Arc::new(Node { - inner: Mutex::new(NodeInner { - children: Vec::new(), - aggregation_leaf: AggregationTreeLeaf::new(), - value: 0, - }), - }); - let start = Instant::now(); - for i in 0..ROOTS { - let node = Arc::new(Node { - inner: Mutex::new(NodeInner { - children: Vec::new(), - aggregation_leaf: AggregationTreeLeaf::new(), - value: 10000 + i, - }), - }); - roots.push(node.clone()); - aggregation_info(&ctx, &NodeRef(node.clone())).lock().active = true; - connect_child(&ctx, &node, &inner_node); - } - println!("Roots: {:?}", start.elapsed()); - let start = Instant::now(); - for i in 0..CHILDREN { - let node = Arc::new(Node { - inner: Mutex::new(NodeInner { - children: Vec::new(), - aggregation_leaf: AggregationTreeLeaf::new(), - value: 20000 + i, - }), - }); - children.push(node.clone()); - connect_child(&ctx, &inner_node, &node); - } - println!("Children: {:?}", start.elapsed()); - let start = Instant::now(); - for i in 0..ROOTS { - let node = Arc::new(Node { - inner: Mutex::new(NodeInner { - children: Vec::new(), - aggregation_leaf: AggregationTreeLeaf::new(), - value: 30000 + i, - }), - }); - roots.push(node.clone()); - aggregation_info(&ctx, &NodeRef(node.clone())).lock().active = true; - connect_child(&ctx, &node, &inner_node); - } - println!("Roots: {:?}", start.elapsed()); - let start = Instant::now(); - for i in 0..CHILDREN { - let node = Arc::new(Node { - inner: Mutex::new(NodeInner { - children: Vec::new(), - aggregation_leaf: AggregationTreeLeaf::new(), - value: 40000 + i, - }), - }); - children.push(node.clone()); - connect_child(&ctx, &inner_node, &node); - } - let children_duration = start.elapsed(); - println!("Children: {:?}", children_duration); - let mut number_of_slow_children = 0; - for _ in 0..10 { - let start = Instant::now(); - for i in 0..CHILDREN { - let node = Arc::new(Node { - inner: Mutex::new(NodeInner { - children: Vec::new(), - aggregation_leaf: AggregationTreeLeaf::new(), - value: 40000 + i, - }), - }); - children.push(node.clone()); - connect_child(&ctx, &inner_node, &node); - } - let dur = start.elapsed(); - println!("Children: {:?}", dur); - if dur > children_duration * 2 { - number_of_slow_children += 1; - } - } - - // Technically it should always be 0, but the performance of the environment - // might vary so we accept a few slow children - assert!(number_of_slow_children < 3); - - let root = NodeRef(roots[0].clone()); - - print(&ctx, &root); -} - -fn connect_child( - aggregation_context: &NodeAggregationContext<'_>, - parent: &Arc, - child: &Arc, -) { - let state = parent.inner.lock(); - let node_ref = NodeRef(child.clone()); - let mut node_guard = unsafe { NodeGuard::new(state, parent.clone()) }; - while let Some(job) = ensure_thresholds(aggregation_context, &mut node_guard) { - drop(node_guard); - job(); - node_guard = unsafe { NodeGuard::new(parent.inner.lock(), parent.clone()) }; - } - let NodeGuard { - guard: mut state, .. - } = node_guard; - state.children.push(child.clone()); - let job = state - .aggregation_leaf - .add_child_job(aggregation_context, &node_ref); - drop(state); - job(); -} - -fn print(aggregation_context: &NodeAggregationContext<'_>, current: &NodeRef) { - println!("digraph {{"); - let start = 0; - let end = 3; - for i in start..end { - print_graph(aggregation_context, current, i, false, |item| { - format!("{}", item.0.inner.lock().value) - }); - } - for i in start + 1..end + 1 { - print_graph(aggregation_context, current, i, true, |item| { - format!("{}", item.0.inner.lock().value) - }); - } - println!("\n}}"); -} diff --git a/crates/turbo-tasks-memory/src/aggregation_tree/top_tree.rs b/crates/turbo-tasks-memory/src/aggregation_tree/top_tree.rs deleted file mode 100644 index da21c55d4d4c4..0000000000000 --- a/crates/turbo-tasks-memory/src/aggregation_tree/top_tree.rs +++ /dev/null @@ -1,180 +0,0 @@ -use std::{mem::transmute, ops::ControlFlow, sync::Arc}; - -use parking_lot::{Mutex, MutexGuard}; -use ref_cast::RefCast; - -use super::{inner_refs::TopRef, leaf::top_tree, AggregationContext}; -use crate::count_hash_set::CountHashSet; - -/// The top half of the aggregation tree. It can aggregate all nodes of a -/// subgraph. To do that it used a [BottomTree] of a specific height and, since -/// a bottom tree only aggregates up to a specific connectivity, also another -/// TopTree of the current depth + 1. This continues recursively until all nodes -/// are aggregated. -pub struct TopTree { - pub depth: u8, - state: Mutex>, -} - -struct TopTreeState { - data: T, - upper: CountHashSet>, -} - -impl TopTree { - pub fn new(depth: u8) -> Self { - Self { - depth, - state: Mutex::new(TopTreeState { - data: T::default(), - upper: CountHashSet::new(), - }), - } - } -} - -impl TopTree { - pub fn add_children_of_child<'a, C: AggregationContext>( - self: &Arc, - aggregation_context: &C, - children: impl IntoIterator, - ) where - C::ItemRef: 'a, - { - for child in children { - top_tree(aggregation_context, child, self.depth + 1) - .add_upper(aggregation_context, self); - } - } - - pub fn add_child_of_child>( - self: &Arc, - aggregation_context: &C, - child_of_child: &C::ItemRef, - ) { - top_tree(aggregation_context, child_of_child, self.depth + 1) - .add_upper(aggregation_context, self); - } - - pub fn remove_child_of_child>( - self: &Arc, - aggregation_context: &C, - child_of_child: &C::ItemRef, - ) { - top_tree(aggregation_context, child_of_child, self.depth + 1) - .remove_upper(aggregation_context, self); - } - - pub fn remove_children_of_child<'a, C: AggregationContext>( - self: &Arc, - aggregation_context: &C, - children: impl IntoIterator, - ) where - C::ItemRef: 'a, - { - for child in children { - top_tree(aggregation_context, child, self.depth + 1) - .remove_upper(aggregation_context, self); - } - } - - pub fn add_upper>( - &self, - aggregation_context: &C, - upper: &Arc>, - ) { - let mut state = self.state.lock(); - if state.upper.add_clonable(TopRef::ref_cast(upper)) { - if let Some(change) = aggregation_context.info_to_add_change(&state.data) { - upper.child_change(aggregation_context, &change); - } - } - } - - pub fn remove_upper>( - &self, - aggregation_context: &C, - upper: &Arc>, - ) { - let mut state = self.state.lock(); - if state.upper.remove_clonable(TopRef::ref_cast(upper)) { - if let Some(change) = aggregation_context.info_to_remove_change(&state.data) { - upper.child_change(aggregation_context, &change); - } - } - } - - pub fn child_change>( - &self, - aggregation_context: &C, - change: &C::ItemChange, - ) { - let mut state = self.state.lock(); - let change = aggregation_context.apply_change(&mut state.data, change); - propagate_change_to_upper(&state, aggregation_context, change); - } - - pub fn get_root_info>( - &self, - aggregation_context: &C, - root_info_type: &C::RootInfoType, - ) -> C::RootInfo { - let state = self.state.lock(); - if self.depth == 0 { - // This is the root - aggregation_context.info_to_root_info(&state.data, root_info_type) - } else { - let mut result = aggregation_context.new_root_info(root_info_type); - for TopRef { upper } in state.upper.iter() { - let info = upper.get_root_info(aggregation_context, root_info_type); - if aggregation_context.merge_root_info(&mut result, info) == ControlFlow::Break(()) - { - break; - } - } - result - } - } - - pub fn lock_info(self: &Arc) -> AggregationInfoGuard { - AggregationInfoGuard { - // SAFETY: We can cast the lifetime as we keep a strong reference to the tree. - // The order of the field in the struct is important to drop guard before tree. - guard: unsafe { transmute(self.state.lock()) }, - tree: self.clone(), - } - } -} - -fn propagate_change_to_upper( - state: &MutexGuard>, - aggregation_context: &C, - change: Option, -) { - let Some(change) = change else { - return; - }; - for TopRef { upper } in state.upper.iter() { - upper.child_change(aggregation_context, &change); - } -} - -pub struct AggregationInfoGuard { - guard: MutexGuard<'static, TopTreeState>, - #[allow(dead_code, reason = "need to stay alive until the guard is dropped")] - tree: Arc>, -} - -impl std::ops::Deref for AggregationInfoGuard { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.guard.data - } -} - -impl std::ops::DerefMut for AggregationInfoGuard { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.guard.data - } -} diff --git a/crates/turbo-tasks-memory/src/count_hash_set.rs b/crates/turbo-tasks-memory/src/count_hash_set.rs index 5e72c31d43524..b118a4c3c0f00 100644 --- a/crates/turbo-tasks-memory/src/count_hash_set.rs +++ b/crates/turbo-tasks-memory/src/count_hash_set.rs @@ -1,5 +1,6 @@ use std::{ borrow::Borrow, + cmp::Ordering, collections::hash_map::RandomState, fmt::{Debug, Formatter}, hash::{BuildHasher, Hash}, @@ -45,6 +46,16 @@ impl Default for CountHashSet { } } +impl FromIterator for CountHashSet { + fn from_iter>(iter: I) -> Self { + let mut set = CountHashSet::default(); + for item in iter { + set.add(item); + } + set + } +} + impl CountHashSet { pub fn new() -> Self { Self::default() @@ -71,6 +82,12 @@ pub enum RemoveIfEntryResult { NotPresent, } +pub struct RemovePositiveCountResult { + pub removed: bool, + pub removed_count: usize, + pub count: isize, +} + impl CountHashSet { /// Returns true, when the value has become visible from outside pub fn add_count(&mut self, item: T, count: usize) -> bool { @@ -135,41 +152,6 @@ impl CountHashSet { } } - /// Returns true when the value is no longer visible from outside - pub fn remove_count(&mut self, item: T, count: usize) -> bool { - if count == 0 { - return false; - } - match self.inner.entry(item) { - Entry::Occupied(mut e) => { - let value = e.get_mut(); - let old = *value; - *value -= count as isize; - if *value > 0 { - // It was and still is positive - false - } else if *value == 0 { - // It was positive and has become zero - e.remove(); - true - } else if old > 0 { - // It was positive and is negative now - self.negative_entries += 1; - true - } else { - // It was and still is negative - false - } - } - Entry::Vacant(e) => { - // It was zero and is negative now - e.insert(-(count as isize)); - self.negative_entries += 1; - false - } - } - } - /// Removes an item if it is present. pub fn remove_if_entry(&mut self, item: &T) -> RemoveIfEntryResult { match self.inner.raw_entry_mut(item) { @@ -197,6 +179,13 @@ impl CountHashSet { count: self.inner.len() - self.negative_entries, } } + + pub fn get_count(&self, item: &T) -> isize { + match self.inner.get(item) { + Some(value) => *value, + None => 0, + } + } } impl CountHashSet { @@ -275,9 +264,71 @@ impl CountHashSet { } } - /// Returns true, when the value is no longer visible from outside - pub fn remove_clonable(&mut self, item: &T) -> bool { - self.remove_clonable_count(item, 1) + /// Returns true when the value is no longer visible from outside + pub fn remove_positive_clonable_count( + &mut self, + item: &T, + count: usize, + ) -> RemovePositiveCountResult { + if count == 0 { + return RemovePositiveCountResult { + removed: false, + removed_count: 0, + count: self.inner.get(item).copied().unwrap_or(0), + }; + } + match self.inner.raw_entry_mut(item) { + RawEntry::Occupied(mut e) => { + let value = e.get_mut(); + let old = *value; + match old.cmp(&(count as isize)) { + Ordering::Less => { + if old < 0 { + // It's already negative, can't remove anything + RemovePositiveCountResult { + removed: false, + removed_count: 0, + count: old, + } + } else { + // It's removed completely with count remaining + e.remove(); + RemovePositiveCountResult { + removed: true, + removed_count: old as usize, + count: 0, + } + } + } + Ordering::Equal => { + // It's perfectly removed + e.remove(); + RemovePositiveCountResult { + removed: true, + removed_count: count, + count: 0, + } + } + Ordering::Greater => { + // It's partially removed + *value -= count as isize; + RemovePositiveCountResult { + removed: false, + removed_count: count, + count: *value, + } + } + } + } + RawEntry::Vacant(_) => { + // It's not present + RemovePositiveCountResult { + removed: false, + removed_count: 0, + count: 0, + } + } + } } } @@ -334,19 +385,19 @@ mod tests { assert_eq!(set.len(), 2); assert!(!set.is_empty()); - assert!(set.remove_count(2, 2)); + assert!(set.remove_clonable_count(&2, 2)); assert_eq!(set.len(), 1); assert!(!set.is_empty()); - assert!(!set.remove_count(2, 1)); + assert!(!set.remove_clonable_count(&2, 1)); assert_eq!(set.len(), 1); assert!(!set.is_empty()); - assert!(!set.remove_count(1, 1)); + assert!(!set.remove_clonable_count(&1, 1)); assert_eq!(set.len(), 1); assert!(!set.is_empty()); - assert!(set.remove_count(1, 1)); + assert!(set.remove_clonable_count(&1, 1)); assert_eq!(set.len(), 0); assert!(set.is_empty()); @@ -366,15 +417,15 @@ mod tests { assert_eq!(set.len(), 0); assert!(set.is_empty()); - assert!(set.add_clonable(&1)); + assert!(set.add_clonable_count(&1, 1)); assert_eq!(set.len(), 1); assert!(!set.is_empty()); - assert!(!set.add_clonable(&1)); + assert!(!set.add_clonable_count(&1, 1)); assert_eq!(set.len(), 1); assert!(!set.is_empty()); - assert!(set.add_clonable(&2)); + assert!(set.add_clonable_count(&2, 1)); assert_eq!(set.len(), 2); assert!(!set.is_empty()); @@ -442,7 +493,7 @@ mod tests { assert_eq!(set.len(), 0); assert!(set.is_empty()); - assert!(!set.remove_count(1, 0)); + assert!(!set.remove_clonable_count(&1, 0)); assert_eq!(set.len(), 0); assert!(set.is_empty()); @@ -454,7 +505,7 @@ mod tests { assert_eq!(set.len(), 0); assert!(set.is_empty()); - assert!(!set.remove_count(1, 1)); + assert!(!set.remove_clonable_count(&1, 1)); assert_eq!(set.len(), 0); assert!(set.is_empty()); diff --git a/crates/turbo-tasks-memory/src/lib.rs b/crates/turbo-tasks-memory/src/lib.rs index 57c491a4554a9..c0fecc3796963 100644 --- a/crates/turbo-tasks-memory/src/lib.rs +++ b/crates/turbo-tasks-memory/src/lib.rs @@ -7,7 +7,7 @@ #![feature(impl_trait_in_assoc_type)] #![deny(unsafe_op_in_unsafe_fn)] -mod aggregation_tree; +mod aggregation; mod cell; mod concurrent_priority_queue; mod count_hash_set; diff --git a/crates/turbo-tasks-memory/src/memory_backend.rs b/crates/turbo-tasks-memory/src/memory_backend.rs index ef4d4cdf56e73..d153b5ded677c 100644 --- a/crates/turbo-tasks-memory/src/memory_backend.rs +++ b/crates/turbo-tasks-memory/src/memory_backend.rs @@ -188,12 +188,14 @@ impl MemoryBackend { } Entry::Occupied(entry) => { // Safety: We have a fresh task id that nobody knows about yet + let task_id = *entry.get(); + drop(entry); unsafe { self.memory_tasks.remove(*new_id); let new_id = Unused::new_unchecked(new_id); turbo_tasks.reuse_task_id(new_id); } - *entry.get() + task_id } }; self.connect_task_child(parent_task, result_task, turbo_tasks); @@ -211,11 +213,15 @@ impl MemoryBackend { K: Borrow + Hash + Eq, Q: Hash + Eq + ?Sized, { - task_cache.get(key).map(|task| { - self.connect_task_child(parent_task, *task, turbo_tasks); - - *task - }) + task_cache + .get(key) + // Avoid holding the lock for too long + .map(|task_ref| *task_ref) + .map(|task_id| { + self.connect_task_child(parent_task, task_id, turbo_tasks); + + task_id + }) } pub(crate) fn schedule_when_dirty_from_aggregation( diff --git a/crates/turbo-tasks-memory/src/task.rs b/crates/turbo-tasks-memory/src/task.rs index 856a5ad8dc2ce..d8a11cdbf0d5d 100644 --- a/crates/turbo-tasks-memory/src/task.rs +++ b/crates/turbo-tasks-memory/src/task.rs @@ -1,20 +1,14 @@ -mod aggregation; -mod meta_state; -mod stats; - use std::{ borrow::Cow, cell::RefCell, cmp::{max, Reverse}, collections::{HashMap, HashSet}, - fmt::{ - Debug, Display, Formatter, {self}, - }, + fmt::{self, Debug, Display, Formatter}, future::Future, hash::{BuildHasherDefault, Hash}, mem::{replace, take}, pin::Pin, - sync::Arc, + sync::{atomic::AtomicU32, Arc}, time::{Duration, Instant}, }; @@ -35,7 +29,10 @@ use turbo_tasks::{ }; use crate::{ - aggregation_tree::{aggregation_info, ensure_thresholds, AggregationInfoGuard}, + aggregation::{ + aggregation_data, handle_new_edge, prepare_aggregation_data, query_root_info, + AggregationDataGuard, PreparedOperation, + }, cell::Cell, gc::{to_exp_u8, GcPriority, GcStats, GcTaskState}, output::{Output, OutputContent}, @@ -47,6 +44,10 @@ use crate::{ pub type NativeTaskFuture = Pin> + Send>>; pub type NativeTaskFn = Box NativeTaskFuture + Send + Sync>; +mod aggregation; +mod meta_state; +mod stats; + #[derive(Hash, Copy, Clone, PartialEq, Eq)] pub enum TaskDependency { Output(TaskId), @@ -138,6 +139,8 @@ pub struct Task { /// The mutable state of the task /// Unset state is equal to a Dirty task that has not been executed yet state: RwLock, + /// Atomic in progress counter for graph modification + graph_modification_in_progress_counter: AtomicU32, } impl Debug for Task { @@ -164,7 +167,7 @@ impl Debug for Task { /// The full state of a [Task], it includes all information. struct TaskState { - aggregation_leaf: TaskAggregationTreeLeaf, + aggregation_node: TaskAggregationNode, // TODO using a Atomic might be possible here /// More flags of task state, where not all combinations are possible. @@ -200,7 +203,7 @@ impl TaskState { stats_type: StatsType, ) -> Self { Self { - aggregation_leaf: TaskAggregationTreeLeaf::new(), + aggregation_node: TaskAggregationNode::new(), state_type: Dirty { event: Event::new(move || format!("TaskState({})::event", description())), outdated_dependencies: Default::default(), @@ -223,7 +226,7 @@ impl TaskState { stats_type: StatsType, ) -> Self { Self { - aggregation_leaf: TaskAggregationTreeLeaf::new(), + aggregation_node: TaskAggregationNode::new(), state_type: Scheduled { event: Event::new(move || format!("TaskState({})::event", description())), outdated_dependencies: Default::default(), @@ -249,13 +252,13 @@ impl TaskState { /// but is still attached to parents and aggregated. struct PartialTaskState { stats_type: StatsType, - aggregation_leaf: TaskAggregationTreeLeaf, + aggregation_leaf: TaskAggregationNode, } impl PartialTaskState { fn into_full(self, description: impl Fn() -> String + Send + Sync + 'static) -> TaskState { TaskState { - aggregation_leaf: self.aggregation_leaf, + aggregation_node: self.aggregation_leaf, state_type: Dirty { event: Event::new(move || format!("TaskState({})::event", description())), outdated_dependencies: Default::default(), @@ -289,7 +292,7 @@ fn test_unloaded_task_state_size() { impl UnloadedTaskState { fn into_full(self, description: impl Fn() -> String + Send + Sync + 'static) -> TaskState { TaskState { - aggregation_leaf: TaskAggregationTreeLeaf::new(), + aggregation_node: TaskAggregationNode::new(), state_type: Dirty { event: Event::new(move || format!("TaskState({})::event", description())), outdated_dependencies: Default::default(), @@ -307,7 +310,7 @@ impl UnloadedTaskState { fn into_partial(self) -> PartialTaskState { PartialTaskState { - aggregation_leaf: TaskAggregationTreeLeaf::new(), + aggregation_leaf: TaskAggregationNode::new(), stats_type: self.stats_type, } } @@ -418,7 +421,7 @@ enum TaskStateType { use TaskStateType::*; use self::{ - aggregation::{Aggregated, RootInfoType, RootType, TaskAggregationTreeLeaf, TaskGuard}, + aggregation::{ActiveQuery, RootType, TaskAggregationNode, TaskGuard}, meta_state::{ FullTaskWriteGuard, TaskMetaState, TaskMetaStateReadGuard, TaskMetaStateWriteGuard, }, @@ -439,6 +442,7 @@ impl Task { description, stats_type, )))), + graph_modification_in_progress_counter: AtomicU32::new(0), } } @@ -456,6 +460,7 @@ impl Task { description, stats_type, )))), + graph_modification_in_progress_counter: AtomicU32::new(0), } } @@ -473,6 +478,7 @@ impl Task { description, stats_type, )))), + graph_modification_in_progress_counter: AtomicU32::new(0), } } @@ -493,7 +499,7 @@ impl Task { { Self::set_root_type( &aggregation_context, - &mut aggregation_context.aggregation_info(id).lock(), + &mut aggregation_context.aggregation_data(id), RootType::Root, ); } @@ -507,19 +513,15 @@ impl Task { ) { let mut aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); { - let aggregation_info = &aggregation_context.aggregation_info(id); - Self::set_root_type( - &aggregation_context, - &mut aggregation_info.lock(), - RootType::Once, - ); + let mut aggregation_guard = aggregation_context.aggregation_data(id); + Self::set_root_type(&aggregation_context, &mut aggregation_guard, RootType::Once); } aggregation_context.apply_queued_updates(); } fn set_root_type( aggregation_context: &TaskAggregationContext, - aggregation: &mut AggregationInfoGuard, + aggregation: &mut AggregationDataGuard>, root_type: RootType, ) { aggregation.root_type = Some(root_type); @@ -540,7 +542,7 @@ impl Task { ) { let mut aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); { - aggregation_context.aggregation_info(id).lock().root_type = None; + aggregation_context.aggregation_data(id).root_type = None; } aggregation_context.apply_queued_updates(); } @@ -628,10 +630,8 @@ impl Task { } TaskDependency::Collectibles(task, trait_type) => { let aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); - let aggregation = aggregation_context.aggregation_info(task); - aggregation - .lock() - .remove_collectible_dependent_task(trait_type, reader); + let mut aggregation = aggregation_context.aggregation_data(task); + aggregation.remove_collectible_dependent_task(trait_type, reader); } } } @@ -887,24 +887,22 @@ impl Task { } } let change_job = state - .aggregation_leaf - .change_job(&aggregation_context, change); + .aggregation_node + .apply_change(&aggregation_context, change); #[cfg(feature = "lazy_remove_children")] let remove_job = if outdated_children.is_empty() { None } else { - Some( - state - .aggregation_leaf - .remove_children_job(&aggregation_context, outdated_children), - ) + Some(state.aggregation_node.handle_lost_edges( + &aggregation_context, + &self.id, + outdated_children, + )) }; drop(state); - change_job(); + change_job.apply(&aggregation_context); #[cfg(feature = "lazy_remove_children")] - if let Some(job) = remove_job { - job(); - } + remove_job.apply(&aggregation_context); } aggregation_context.apply_queued_updates(); } @@ -1014,18 +1012,16 @@ impl Task { change.collectibles.push((trait_type, value, -count)); } } - change_job = Some( - state - .aggregation_leaf - .change_job(&aggregation_context, change), - ); + change_job = state + .aggregation_node + .apply_change(&aggregation_context, change); } #[cfg(feature = "lazy_remove_children")] if !outdated_children.is_empty() { - remove_job = Some( - state - .aggregation_leaf - .remove_children_job(&aggregation_context, outdated_children), + remove_job = state.aggregation_node.handle_lost_edges( + &aggregation_context, + &self.id, + outdated_children, ); } event.notify(usize::MAX); @@ -1049,20 +1045,13 @@ impl Task { if !dependencies.is_empty() { self.clear_dependencies(dependencies, backend, turbo_tasks); } - if let Some(job) = change_job { - job(); - } + change_job.apply(&aggregation_context); #[cfg(feature = "lazy_remove_children")] - if let Some(job) = remove_job { - job(); - } + remove_job.apply(&aggregation_context); } if let TaskType::Once(_) = self.ty { // unset the root type, so tasks below are no longer active - aggregation_context - .aggregation_info(self.id) - .lock() - .root_type = None; + aggregation_context.aggregation_data(self.id).root_type = None; } aggregation_context.apply_queued_updates(); @@ -1088,12 +1077,14 @@ impl Task { return; } + let mut aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); + let active = query_root_info(&aggregation_context, ActiveQuery::default(), self.id); + let state = if force_schedule { TaskMetaStateWriteGuard::Full(self.full_state_mut()) } else { self.state_mut() }; - let mut aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); if let TaskMetaStateWriteGuard::Full(mut state) = state { match state.state_type { Scheduled { .. } | InProgressDirty { .. } => { @@ -1112,14 +1103,15 @@ impl Task { }), outdated_dependencies: take(outdated_dependencies), }; - state.aggregation_leaf.change( + let change_job = state.aggregation_node.apply_change( &aggregation_context, - &TaskChange { + TaskChange { dirty_tasks_update: vec![(self.id, -1)], ..Default::default() }, ); drop(state); + change_job.apply(&aggregation_context); turbo_tasks.schedule(self.id); } else { // already dirty @@ -1129,51 +1121,20 @@ impl Task { Done { ref mut dependencies, } => { - let mut has_set_unfinished = false; let outdated_dependencies = take(dependencies); // add to dirty lists and potentially schedule let description = self.get_event_description(); - let should_schedule = force_schedule - || state - .aggregation_leaf - .get_root_info(&aggregation_context, &RootInfoType::IsActive) - || { - state.aggregation_leaf.change( - &aggregation_context, - &TaskChange { - unfinished: 1, - #[cfg(feature = "track_unfinished")] - unfinished_tasks_update: vec![(self.id, 1)], - dirty_tasks_update: vec![(self.id, 1)], - ..Default::default() - }, - ); - has_set_unfinished = true; - if aggregation_context.take_scheduled_dirty_task(self.id) { - state.aggregation_leaf.change( - &aggregation_context, - &TaskChange { - dirty_tasks_update: vec![(self.id, -1)], - ..Default::default() - }, - ); - true - } else { - false - } - }; - if !has_set_unfinished { - state.aggregation_leaf.change( + let should_schedule = force_schedule || active; + if should_schedule { + let change_job = state.aggregation_node.apply_change( &aggregation_context, - &TaskChange { + TaskChange { unfinished: 1, #[cfg(feature = "track_unfinished")] unfinished_tasks_update: vec![(self.id, 1)], ..Default::default() }, ); - } - if should_schedule { state.state_type = Scheduled { event: Event::new(move || { format!("TaskState({})::event", description()) @@ -1181,12 +1142,23 @@ impl Task { outdated_dependencies, }; drop(state); + change_job.apply(&aggregation_context); if cfg!(feature = "print_task_invalidation") { println!("invalidated Task {{ id: {}, name: {} }}", *self.id, self.ty); } turbo_tasks.schedule(self.id); } else { + let change_job = state.aggregation_node.apply_change( + &aggregation_context, + TaskChange { + unfinished: 1, + #[cfg(feature = "track_unfinished")] + unfinished_tasks_update: vec![(self.id, 1)], + dirty_tasks_update: vec![(self.id, 1)], + ..Default::default() + }, + ); state.state_type = Dirty { event: Event::new(move || { format!("TaskState({})::event", description()) @@ -1194,6 +1166,7 @@ impl Task { outdated_dependencies, }; drop(state); + change_job.apply(&aggregation_context); } } InProgress { @@ -1229,22 +1202,22 @@ impl Task { } else { None }; - let change_job = change.map(|change| { + let change_job = change.and_then(|change| { state - .aggregation_leaf - .change_job(&aggregation_context, change) + .aggregation_node + .apply_change(&aggregation_context, change) }); #[cfg(feature = "lazy_remove_children")] - let remove_job = state - .aggregation_leaf - .remove_children_job(&aggregation_context, outdated_children); + let remove_job = state.aggregation_node.handle_lost_edges( + &aggregation_context, + &self.id, + outdated_children, + ); state.state_type = InProgressDirty { event }; drop(state); - if let Some(job) = change_job { - job(); - } + change_job.apply(&aggregation_context); #[cfg(feature = "lazy_remove_children")] - remove_job(); + remove_job.apply(&aggregation_context); } } } @@ -1256,18 +1229,18 @@ impl Task { backend: &MemoryBackend, turbo_tasks: &dyn TurboTasksBackendApi, ) { - let aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); let mut state = self.full_state_mut(); if let TaskStateType::Dirty { ref mut event, ref mut outdated_dependencies, } = state.state_type { + let mut aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); state.state_type = Scheduled { event: event.take(), outdated_dependencies: take(outdated_dependencies), }; - let job = state.aggregation_leaf.change_job( + let job = state.aggregation_node.apply_change( &aggregation_context, TaskChange { dirty_tasks_update: vec![(self.id, -1)], @@ -1276,7 +1249,8 @@ impl Task { ); drop(state); turbo_tasks.schedule(self.id); - job(); + job.apply(&aggregation_context); + aggregation_context.apply_queued_updates(); } } @@ -1500,21 +1474,7 @@ impl Task { { let mut add_job = None; { - let mut guard = TaskGuard { - id: self.id, - guard: self.state_mut(), - }; - while let Some(thresholds_job) = ensure_thresholds(&aggregation_context, &mut guard) - { - drop(guard); - thresholds_job(); - guard = TaskGuard { - id: self.id, - guard: self.state_mut(), - }; - } - let TaskGuard { guard, .. } = guard; - let mut state = TaskMetaStateWriteGuard::full_from(guard.into_inner(), self); + let mut state = self.full_state_mut(); if state.children.insert(child_id) { #[cfg(feature = "lazy_remove_children")] if let TaskStateType::InProgress { @@ -1525,11 +1485,15 @@ impl Task { return; } } - add_job = Some( - state - .aggregation_leaf - .add_child_job(&aggregation_context, &child_id), - ); + let number_of_children = state.children.len(); + let mut guard = TaskGuard::from_full(self.id, state); + add_job = Some(handle_new_edge( + &aggregation_context, + &mut guard, + &self.id, + &child_id, + number_of_children, + )); } } if let Some(job) = add_job { @@ -1539,22 +1503,14 @@ impl Task { // So it's fine to ignore the race condition existing here. backend.with_task(child_id, |child| { if child.is_dirty() { - let state = self.state(); - let active = match state { - TaskMetaStateReadGuard::Full(state) => state - .aggregation_leaf - .get_root_info(&aggregation_context, &RootInfoType::IsActive), - TaskMetaStateReadGuard::Partial(state) => state - .aggregation_leaf - .get_root_info(&aggregation_context, &RootInfoType::IsActive), - TaskMetaStateReadGuard::Unloaded => false, - }; + let active = + query_root_info(&aggregation_context, ActiveQuery::default(), self.id); if active { child.schedule_when_dirty_from_aggregation(backend, turbo_tasks); } } }); - job(); + job.apply(&aggregation_context); } } aggregation_context.apply_queued_updates(); @@ -1569,34 +1525,35 @@ impl Task { turbo_tasks: &dyn TurboTasksBackendApi, ) -> Result> { let mut aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); - let aggregation_when_strongly_consistent = - strongly_consistent.then(|| aggregation_info(&aggregation_context, &self.id)); - let mut state = self.full_state_mut(); - if let Some(aggregation) = aggregation_when_strongly_consistent { - { - let mut aggregation = aggregation.lock(); - if aggregation.unfinished > 0 { - if aggregation.root_type.is_none() { - Self::set_root_type( - &aggregation_context, - &mut aggregation, - RootType::ReadingStronglyConsistent, - ); - } - let listener = aggregation.unfinished_event.listen_with_note(note); - drop(aggregation); - drop(state); - aggregation_context.apply_queued_updates(); - - return Ok(Err(listener)); - } else if matches!( - aggregation.root_type, - Some(RootType::ReadingStronglyConsistent) - ) { - aggregation.root_type = None; + if strongly_consistent { + prepare_aggregation_data(&aggregation_context, &self.id); + } + let mut state = if strongly_consistent { + let mut aggregation = aggregation_data(&aggregation_context, &self.id); + if aggregation.unfinished > 0 { + if aggregation.root_type.is_none() { + Self::set_root_type( + &aggregation_context, + &mut aggregation, + RootType::ReadingStronglyConsistent, + ); } + let listener = aggregation.unfinished_event.listen_with_note(note); + drop(aggregation); + aggregation_context.apply_queued_updates(); + + return Ok(Err(listener)); + } else if matches!( + aggregation.root_type, + Some(RootType::ReadingStronglyConsistent) + ) { + aggregation.root_type = None; } - } + let state = aggregation.into_inner().into_inner().into_inner(); + TaskMetaStateWriteGuard::full_from(state, self) + } else { + self.full_state_mut() + }; let result = match state.state_type { Done { .. } => { let result = func(&mut state.output)?; @@ -1615,14 +1572,15 @@ impl Task { event, outdated_dependencies: take(outdated_dependencies), }; - state.aggregation_leaf.change( + let change_job = state.aggregation_node.apply_change( &aggregation_context, - &TaskChange { + TaskChange { dirty_tasks_update: vec![(self.id, -1)], ..Default::default() }, ); drop(state); + change_job.apply(&aggregation_context); Ok(Err(listener)) } Scheduled { ref event, .. } @@ -1645,10 +1603,8 @@ impl Task { turbo_tasks: &dyn TurboTasksBackendApi, ) -> AutoMap { let aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); - aggregation_context - .aggregation_info(id) - .lock() - .read_collectibles(trait_type, reader) + let mut aggregation_data = aggregation_context.aggregation_data(id); + aggregation_data.read_collectibles(trait_type, reader) } pub(crate) fn emit_collectible( @@ -1661,14 +1617,15 @@ impl Task { let mut aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); let mut state = self.full_state_mut(); state.collectibles.emit(trait_type, collectible); - state.aggregation_leaf.change( + let change_job = state.aggregation_node.apply_change( &aggregation_context, - &TaskChange { + TaskChange { collectibles: vec![(trait_type, collectible, 1)], ..Default::default() }, ); drop(state); + change_job.apply(&aggregation_context); aggregation_context.apply_queued_updates(); } @@ -1683,14 +1640,15 @@ impl Task { let mut aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); let mut state = self.full_state_mut(); state.collectibles.unemit(trait_type, collectible, count); - state.aggregation_leaf.change( + let change_job = state.aggregation_node.apply_change( &aggregation_context, - &TaskChange { + TaskChange { collectibles: vec![(trait_type, collectible, -(count as i32))], ..Default::default() }, ); drop(state); + change_job.apply(&aggregation_context); aggregation_context.apply_queued_updates(); } @@ -1739,6 +1697,9 @@ impl Task { }) } + let aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); + let active = query_root_info(&aggregation_context, ActiveQuery::default(), self.id); + if let TaskMetaStateWriteGuard::Full(mut state) = self.state_mut() { if state.stateful { stats.no_gc_possible += 1; @@ -1758,20 +1719,9 @@ impl Task { } // Check if the task need to be activated again - let active = if state.gc.inactive { - let active = state.aggregation_leaf.get_root_info( - &TaskAggregationContext::new(turbo_tasks, backend), - &RootInfoType::IsActive, - ); - if active { - state.gc.inactive = false; - true - } else { - false - } - } else { - true - }; + if state.gc.inactive && active { + state.gc.inactive = false; + } let last_duration = state.stats.last_duration(); let compute_duration = last_duration.into(); @@ -1999,9 +1949,11 @@ impl Task { backend: &MemoryBackend, turbo_tasks: &dyn TurboTasksBackendApi, ) -> bool { + let mut aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); let mut clear_dependencies = None; + let mut change_job = None; let TaskState { - ref mut aggregation_leaf, + aggregation_node: ref mut aggregation_leaf, ref mut state_type, .. } = *full_state; @@ -2009,28 +1961,14 @@ impl Task { Done { ref mut dependencies, } => { - let mut aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); - aggregation_leaf.change( - &TaskAggregationContext::new(turbo_tasks, backend), - &TaskChange { + change_job = aggregation_leaf.apply_change( + &aggregation_context, + TaskChange { unfinished: 1, dirty_tasks_update: vec![(self.id, 1)], ..Default::default() }, ); - if aggregation_context.take_scheduled_dirty_task(self.id) { - // Unloading is only possible for inactive tasks. - // We need to abort the unloading, so revert changes done so far. - aggregation_leaf.change( - &TaskAggregationContext::new(turbo_tasks, backend), - &TaskChange { - unfinished: -1, - dirty_tasks_update: vec![(self.id, -1)], - ..Default::default() - }, - ); - return false; - } clear_dependencies = Some(take(dependencies)); } Dirty { @@ -2066,7 +2004,7 @@ impl Task { cells, output, collectibles, - aggregation_leaf, + mut aggregation_node, stats, // can be dropped as it will be recomputed on next execution stateful: _, @@ -2078,33 +2016,30 @@ impl Task { gc: _, } = old_state.into_full().unwrap(); - let aggregation_context = TaskAggregationContext::new(turbo_tasks, backend); - // Remove all children, as they will be added again when this task is executed // again - if !children.is_empty() { - for child in children { - aggregation_leaf.remove_child(&aggregation_context, &child); - } - } + let remove_job = (!children.is_empty()) + .then(|| aggregation_node.handle_lost_edges(&aggregation_context, &self.id, children)); // Remove all collectibles, as they will be added again when this task is // executed again. - if let Some(collectibles) = collectibles.into_inner() { - aggregation_leaf.change( + let collectibles_job = if let Some(collectibles) = collectibles.into_inner() { + aggregation_node.apply_change( &aggregation_context, - &TaskChange { + TaskChange { collectibles: collectibles .into_iter() .map(|((t, r), c)| (t, r, -c)) .collect(), ..Default::default() }, - ); - } + ) + } else { + None + }; // TODO aggregation_leaf - let unset = !aggregation_leaf.has_upper(); + let unset = false; let stats_type = match stats { TaskStats::Essential(_) => StatsType::Essential, @@ -2114,12 +2049,16 @@ impl Task { *state = TaskMetaState::Unloaded(UnloadedTaskState { stats_type }); } else { *state = TaskMetaState::Partial(Box::new(PartialTaskState { - aggregation_leaf, + aggregation_leaf: aggregation_node, stats_type, })); } drop(state); + change_job.apply(&aggregation_context); + remove_job.apply(&aggregation_context); + collectibles_job.apply(&aggregation_context); + // Notify everyone that is listening on our output or cells. // This will mark everyone as dirty and will trigger a new execution when they // become active again. @@ -2135,6 +2074,8 @@ impl Task { self.clear_dependencies(dependencies, backend, turbo_tasks); } + aggregation_context.apply_queued_updates(); + true } } diff --git a/crates/turbo-tasks-memory/src/task/aggregation.rs b/crates/turbo-tasks-memory/src/task/aggregation.rs index 92749f364fd42..088beb941379e 100644 --- a/crates/turbo-tasks-memory/src/task/aggregation.rs +++ b/crates/turbo-tasks-memory/src/task/aggregation.rs @@ -1,7 +1,9 @@ use std::{ - borrow::Cow, + cmp::Ordering, hash::{BuildHasher, Hash}, mem::take, + ops::{ControlFlow, Deref, DerefMut}, + sync::atomic::AtomicU32, }; use auto_hash_map::{map::Entry, AutoMap}; @@ -9,11 +11,14 @@ use nohash_hasher::BuildNoHashHasher; use parking_lot::Mutex; use turbo_tasks::{event::Event, RawVc, TaskId, TaskIdSet, TraitTypeId, TurboTasksBackendApi}; -use super::{meta_state::TaskMetaStateWriteGuard, TaskStateType}; +use super::{ + meta_state::{FullTaskWriteGuard, TaskMetaStateWriteGuard}, + TaskStateType, +}; use crate::{ - aggregation_tree::{ - aggregation_info, AggregationContext, AggregationInfoReference, AggregationItemLock, - AggregationTreeLeaf, + aggregation::{ + aggregation_data, AggregationContext, AggregationDataGuard, AggregationNode, + AggregationNodeGuard, RootQuery, }, MemoryBackend, }; @@ -36,10 +41,6 @@ impl CollectiblesInfo { } } -pub enum RootInfoType { - IsActive, -} - pub struct Aggregated { /// The number of unfinished items in the lower aggregation level. /// Unfinished means not "Done". @@ -159,14 +160,6 @@ impl<'a> TaskAggregationContext<'a> { } } - pub fn take_scheduled_dirty_task(&mut self, task: TaskId) -> bool { - let dirty_task_to_schedule = self.dirty_tasks_to_schedule.get_mut(); - dirty_task_to_schedule - .as_mut() - .map(|t| t.remove(&task)) - .unwrap_or(false) - } - pub fn apply_queued_updates(&mut self) { { let mut _span = None; @@ -191,8 +184,8 @@ impl<'a> TaskAggregationContext<'a> { } } - pub fn aggregation_info(&self, id: TaskId) -> AggregationInfoReference { - aggregation_info(self, &id) + pub fn aggregation_data(&self, id: TaskId) -> AggregationDataGuard> { + aggregation_data(self, &id) } } @@ -215,26 +208,31 @@ impl<'a> Drop for TaskAggregationContext<'a> { } impl<'a> AggregationContext for TaskAggregationContext<'a> { - type ItemLock<'l> = TaskGuard<'l> where Self: 'l; - type Info = Aggregated; - type ItemChange = TaskChange; - type ItemRef = TaskId; - type RootInfo = bool; - type RootInfoType = RootInfoType; - - fn item<'b>(&'b self, reference: &TaskId) -> Self::ItemLock<'b> { + type Guard<'l> = TaskGuard<'l> where Self: 'l; + type Data = Aggregated; + type DataChange = TaskChange; + type NodeRef = TaskId; + + fn node<'b>(&'b self, reference: &TaskId) -> Self::Guard<'b> { let task = self.backend.task(*reference); - TaskGuard { - id: *reference, - guard: task.state_mut(), - } + TaskGuard::new(*reference, task.state_mut()) + } + + fn atomic_in_progress_counter<'l>(&self, id: &'l TaskId) -> &'l AtomicU32 + where + Self: 'l, + { + &self + .backend + .task(*id) + .graph_modification_in_progress_counter } fn apply_change( &self, info: &mut Aggregated, - change: &Self::ItemChange, - ) -> Option { + change: &Self::DataChange, + ) -> Option { let mut unfinished = 0; if info.unfinished > 0 { info.unfinished += change.unfinished; @@ -249,15 +247,28 @@ impl<'a> AggregationContext for TaskAggregationContext<'a> { } } #[cfg(feature = "track_unfinished")] + let mut unfinished_tasks_update = Vec::new(); + #[cfg(feature = "track_unfinished")] for &(task, count) in change.unfinished_tasks_update.iter() { - update_count_entry(info.unfinished_tasks.entry(task), count); + match update_count_entry(info.unfinished_tasks.entry(task), count) { + (_, UpdateCountEntryChange::Removed) => unfinished_tasks_update.push((task, -1)), + (_, UpdateCountEntryChange::Inserted) => unfinished_tasks_update.push((task, 1)), + _ => {} + } } + let mut dirty_tasks_update = Vec::new(); let is_root = info.root_type.is_some(); for &(task, count) in change.dirty_tasks_update.iter() { - let value = update_count_entry(info.dirty_tasks.entry(task), count); - if is_root && value > 0 && value <= count { - let mut tasks_to_schedule = self.dirty_tasks_to_schedule.lock(); - tasks_to_schedule.get_or_insert_default().insert(task); + match update_count_entry(info.dirty_tasks.entry(task), count) { + (_, UpdateCountEntryChange::Removed) => dirty_tasks_update.push((task, -1)), + (_, UpdateCountEntryChange::Inserted) => { + if is_root { + let mut tasks_to_schedule = self.dirty_tasks_to_schedule.lock(); + tasks_to_schedule.get_or_insert_default().insert(task); + } + dirty_tasks_update.push((task, 1)) + } + _ => {} } } for &(trait_type_id, collectible, count) in change.collectibles.iter() { @@ -265,7 +276,7 @@ impl<'a> AggregationContext for TaskAggregationContext<'a> { match collectibles_info_entry { Entry::Occupied(mut e) => { let collectibles_info = e.get_mut(); - let value = update_count_entry( + let (value, _) = update_count_entry( collectibles_info.collectibles.entry(collectible), count, ); @@ -298,8 +309,8 @@ impl<'a> AggregationContext for TaskAggregationContext<'a> { let new_change = TaskChange { unfinished, #[cfg(feature = "track_unfinished")] - unfinished_tasks_update: change.unfinished_tasks_update.clone(), - dirty_tasks_update: change.dirty_tasks_update.clone(), + unfinished_tasks_update: unfinished_tasks_update, + dirty_tasks_update, collectibles: change.collectibles.clone(), }; if new_change.is_empty() { @@ -309,19 +320,23 @@ impl<'a> AggregationContext for TaskAggregationContext<'a> { } } - fn info_to_add_change(&self, info: &Aggregated) -> Option { + fn data_to_add_change(&self, data: &Aggregated) -> Option { let mut change = TaskChange::default(); - if info.unfinished > 0 { + if data.unfinished > 0 { change.unfinished = 1; } #[cfg(feature = "track_unfinished")] - for (&task, &count) in info.unfinished_tasks.iter() { - change.unfinished_tasks_update.push((task, count)); + for (&task, &count) in data.unfinished_tasks.iter() { + if count > 0 { + change.unfinished_tasks_update.push((task, 1)); + } } - for (&task, &count) in info.dirty_tasks.iter() { - change.dirty_tasks_update.push((task, count)); + for (&task, &count) in data.dirty_tasks.iter() { + if count > 0 { + change.dirty_tasks_update.push((task, 1)); + } } - for (trait_type_id, collectibles_info) in info.collectibles.iter() { + for (trait_type_id, collectibles_info) in data.collectibles.iter() { for (collectible, count) in collectibles_info.collectibles.iter() { change .collectibles @@ -335,19 +350,19 @@ impl<'a> AggregationContext for TaskAggregationContext<'a> { } } - fn info_to_remove_change(&self, info: &Aggregated) -> Option { + fn data_to_remove_change(&self, data: &Aggregated) -> Option { let mut change = TaskChange::default(); - if info.unfinished > 0 { + if data.unfinished > 0 { change.unfinished = -1; } #[cfg(feature = "track_unfinished")] - for (&task, &count) in info.unfinished_tasks.iter() { + for (&task, &count) in data.unfinished_tasks.iter() { change.unfinished_tasks_update.push((task, -count)); } - for (&task, &count) in info.dirty_tasks.iter() { + for (&task, &count) in data.dirty_tasks.iter() { change.dirty_tasks_update.push((task, -count)); } - for (trait_type_id, collectibles_info) in info.collectibles.iter() { + for (trait_type_id, collectibles_info) in data.collectibles.iter() { for (collectible, count) in collectibles_info.collectibles.iter() { change .collectibles @@ -360,71 +375,84 @@ impl<'a> AggregationContext for TaskAggregationContext<'a> { Some(change) } } +} - fn new_root_info(&self, _root_info_type: &RootInfoType) -> Self::RootInfo { - false - } +#[derive(Default)] +pub struct ActiveQuery { + active: bool, +} - fn info_to_root_info( - &self, - info: &Aggregated, - root_info_type: &RootInfoType, - ) -> Self::RootInfo { - match root_info_type { - RootInfoType::IsActive => info.root_type.is_some(), - } - } +impl RootQuery for ActiveQuery { + type Data = Aggregated; + type Result = bool; - fn merge_root_info( - &self, - root_info: &mut Self::RootInfo, - other: Self::RootInfo, - ) -> std::ops::ControlFlow<()> { - if other { - *root_info = true; - std::ops::ControlFlow::Break(()) + fn query(&mut self, data: &Self::Data) -> ControlFlow<()> { + if data.root_type.is_some() { + self.active = true; + ControlFlow::Break(()) } else { - std::ops::ControlFlow::Continue(()) + ControlFlow::Continue(()) } } + + fn result(self) -> Self::Result { + self.active + } } pub struct TaskGuard<'l> { - pub(super) id: TaskId, - pub(super) guard: TaskMetaStateWriteGuard<'l>, + id: TaskId, + guard: TaskMetaStateWriteGuard<'l>, } -impl<'l> AggregationItemLock for TaskGuard<'l> { - type Info = Aggregated; - type ItemRef = TaskId; - type ItemChange = TaskChange; - type ChildrenIter<'a> = impl Iterator> + 'a where Self: 'a; +impl<'l> TaskGuard<'l> { + pub fn new(id: TaskId, mut guard: TaskMetaStateWriteGuard<'l>) -> Self { + guard.ensure_at_least_partial(); + Self { id, guard } + } - fn leaf(&mut self) -> &mut AggregationTreeLeaf { - self.guard.ensure_at_least_partial(); - match self.guard { - TaskMetaStateWriteGuard::Full(ref mut guard) => &mut guard.aggregation_leaf, - TaskMetaStateWriteGuard::Partial(ref mut guard) => &mut guard.aggregation_leaf, - TaskMetaStateWriteGuard::Unloaded(_) => unreachable!(), + pub fn from_full(id: TaskId, guard: FullTaskWriteGuard<'l>) -> Self { + Self { + id, + guard: TaskMetaStateWriteGuard::Full(guard), } } - fn reference(&self) -> &Self::ItemRef { - &self.id + pub fn into_inner(self) -> TaskMetaStateWriteGuard<'l> { + self.guard } +} - fn number_of_children(&self) -> usize { +impl<'l> Deref for TaskGuard<'l> { + type Target = AggregationNode< + ::NodeRef, + ::Data, + >; + + fn deref(&self) -> &Self::Target { match self.guard { - TaskMetaStateWriteGuard::Full(ref guard) => match &guard.state_type { - #[cfg(feature = "lazy_remove_children")] - TaskStateType::InProgress { - outdated_children, .. - } => guard.children.len() + outdated_children.len(), - _ => guard.children.len(), - }, - TaskMetaStateWriteGuard::Partial(_) | TaskMetaStateWriteGuard::Unloaded(_) => 0, + TaskMetaStateWriteGuard::Full(ref guard) => &guard.aggregation_node, + TaskMetaStateWriteGuard::Partial(ref guard) => &guard.aggregation_leaf, + TaskMetaStateWriteGuard::Unloaded(_) => unreachable!(), + } + } +} + +impl<'l> DerefMut for TaskGuard<'l> { + fn deref_mut(&mut self) -> &mut Self::Target { + match self.guard { + TaskMetaStateWriteGuard::Full(ref mut guard) => &mut guard.aggregation_node, + TaskMetaStateWriteGuard::Partial(ref mut guard) => &mut guard.aggregation_leaf, + TaskMetaStateWriteGuard::Unloaded(_) => unreachable!(), } } +} + +impl<'l> AggregationNodeGuard for TaskGuard<'l> { + type Data = Aggregated; + type NodeRef = TaskId; + type DataChange = TaskChange; + type ChildrenIter<'a> = impl Iterator + 'a where Self: 'a; fn children(&self) -> Self::ChildrenIter<'_> { match self.guard { @@ -434,14 +462,14 @@ impl<'l> AggregationItemLock for TaskGuard<'l> { let outdated_children = match &guard.state_type { TaskStateType::InProgress { outdated_children, .. - } => Some(outdated_children.iter().map(Cow::Borrowed)), + } => Some(outdated_children.iter().copied()), _ => None, }; Some( guard .children .iter() - .map(Cow::Borrowed) + .copied() .chain(outdated_children.into_iter().flatten()), ) .into_iter() @@ -449,9 +477,7 @@ impl<'l> AggregationItemLock for TaskGuard<'l> { } #[cfg(not(feature = "lazy_remove_children"))] { - Some(guard.children.iter().map(Cow::Borrowed)) - .into_iter() - .flatten() + Some(guard.children.iter().copied()).into_iter().flatten() } } TaskMetaStateWriteGuard::Partial(_) | TaskMetaStateWriteGuard::Unloaded(_) => { @@ -460,7 +486,7 @@ impl<'l> AggregationItemLock for TaskGuard<'l> { } } - fn get_add_change(&self) -> Option { + fn get_add_change(&self) -> Option { match self.guard { TaskMetaStateWriteGuard::Full(ref guard) => { let mut change = TaskChange::default(); @@ -505,7 +531,7 @@ impl<'l> AggregationItemLock for TaskGuard<'l> { } } - fn get_remove_change(&self) -> Option { + fn get_remove_change(&self) -> Option { match self.guard { TaskMetaStateWriteGuard::Full(ref guard) => { let mut change = TaskChange::default(); @@ -549,28 +575,84 @@ impl<'l> AggregationItemLock for TaskGuard<'l> { TaskMetaStateWriteGuard::Partial(_) | TaskMetaStateWriteGuard::Unloaded(_) => None, } } + + fn get_initial_data(&self) -> Self::Data { + let mut data = Aggregated::default(); + if let Some(TaskChange { + unfinished, + #[cfg(feature = "track_unfinished")] + unfinished_tasks_update, + dirty_tasks_update, + collectibles, + }) = self.get_add_change() + { + data.unfinished = unfinished; + #[cfg(feature = "track_unfinished")] + { + data.unfinished_tasks = unfinished_tasks_update.into_iter().collect(); + } + data.dirty_tasks = dirty_tasks_update.into_iter().collect(); + data.collectibles = collectibles + .into_iter() + .map(|(trait_type_id, collectible, count)| { + ( + trait_type_id, + CollectiblesInfo { + collectibles: [(collectible, count)].iter().cloned().collect(), + dependent_tasks: TaskIdSet::default(), + }, + ) + }) + .collect(); + } + data + } } -pub type TaskAggregationTreeLeaf = AggregationTreeLeaf; +pub type TaskAggregationNode = AggregationNode; + +enum UpdateCountEntryChange { + Removed, + Inserted, + Updated, +} fn update_count_entry( entry: Entry<'_, K, i32, H, I>, update: i32, -) -> i32 { +) -> (i32, UpdateCountEntryChange) { match entry { Entry::Occupied(mut e) => { let value = e.get_mut(); - *value += update; - if *value == 0 { - e.remove(); - 0 + if *value < 0 { + *value += update; + match (*value).cmp(&0) { + Ordering::Less => (*value, UpdateCountEntryChange::Updated), + Ordering::Equal => { + e.remove(); + (0, UpdateCountEntryChange::Updated) + } + Ordering::Greater => (*value, UpdateCountEntryChange::Inserted), + } } else { - *value + *value += update; + match (*value).cmp(&0) { + Ordering::Less => (*value, UpdateCountEntryChange::Removed), + Ordering::Equal => { + e.remove(); + (0, UpdateCountEntryChange::Removed) + } + Ordering::Greater => (*value, UpdateCountEntryChange::Updated), + } } } Entry::Vacant(e) => { - e.insert(update); - update + if update == 0 { + (0, UpdateCountEntryChange::Updated) + } else { + e.insert(update); + (update, UpdateCountEntryChange::Inserted) + } } } }