From 71b6358a959cfc613456455435b6ffd5afed3cdb Mon Sep 17 00:00:00 2001 From: Longarithm Date: Wed, 18 Sep 2024 00:45:34 +0400 Subject: [PATCH 1/8] stub --- chain/chain/src/chain.rs | 106 +++++++-- core/store/src/trie/mem/mod.rs | 1 + core/store/src/trie/mem/resharding.rs | 328 ++++++++++++++++++++++++++ core/store/src/trie/mem/updating.rs | 54 ++--- core/store/src/trie/mod.rs | 5 + 5 files changed, 453 insertions(+), 41 deletions(-) create mode 100644 core/store/src/trie/mem/resharding.rs diff --git a/chain/chain/src/chain.rs b/chain/chain/src/chain.rs index e7af6450e02..6a4c32e0d57 100644 --- a/chain/chain/src/chain.rs +++ b/chain/chain/src/chain.rs @@ -92,13 +92,16 @@ use near_primitives::views::{ }; use near_store::config::StateSnapshotType; use near_store::flat::{store_helper, FlatStorageReadyStatus, FlatStorageStatus}; -use near_store::get_genesis_state_roots; +use near_store::trie::mem::resharding::RetainMode; +use near_store::trie::mem::updating::apply_memtrie_changes; use near_store::DBCol; +use near_store::{get_genesis_state_roots, PartialStorage}; use node_runtime::bootstrap_congestion_info; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::{Debug, Formatter}; use std::num::NonZeroUsize; +use std::str::FromStr; use std::sync::Arc; use time::ext::InstantExt as _; use tracing::{debug, debug_span, error, info, warn, Span}; @@ -1847,6 +1850,77 @@ impl Chain { }); } + /// If shard layout changes after the given block, creates temporary + /// memtries for new shards to be able to process them in the next epoch. + /// Note this doesn't complete resharding, proper memtries are to be + /// created later. 
+ fn process_instant_resharding_storage_update( + &mut self, + block: &Block, + shard_uid: ShardUId, + ) -> Result<(), Error> { + let block_hash = block.hash(); + let block_height = block.header().height(); + let prev_hash = block.header().prev_hash(); + if !self.epoch_manager.will_shard_layout_change(prev_hash)? { + return Ok(()); + } + + let next_epoch_id = self.epoch_manager.get_next_epoch_id_from_prev_block(prev_hash)?; + let next_shard_layout = self.epoch_manager.get_shard_layout(&next_epoch_id)?; + let children_shard_uids = + next_shard_layout.get_children_shards_uids(shard_uid.shard_id()).unwrap(); + + // Hack to ensure this logic is not applied before ReshardingV3. + // TODO(#12019): proper logic. + if next_shard_layout.version() < 3 || children_shard_uids.len() == 1 { + return Ok(()); + } + assert_eq!(children_shard_uids.len(), 2); + + let chunk_extra = self.get_chunk_extra(block_hash, &shard_uid)?; + let tries = self.runtime_adapter.get_tries(); + let Some(mem_tries) = tries.get_mem_tries(shard_uid) else { + // TODO(#12019): what if node doesn't have memtrie? just pause + // processing? + return Ok(()); + }; + let boundary_account = AccountId::from_str("boundary.near").unwrap(); + + // TODO(#12019): leave only tracked shards. 
+ for (new_shard_uid, retain_mode) in [ + (children_shard_uids[0], RetainMode::Left), + (children_shard_uids[1], RetainMode::Right), + ] { + let mut mem_tries = mem_tries.write().unwrap(); + let mut mem_trie_update = mem_tries.update(*chunk_extra.state_root(), false)?; + + let (trie_changes, partial_state) = + mem_trie_update.cut(boundary_account.clone(), retain_mode); + let partial_storage = PartialStorage { nodes: partial_state }; + let mem_changes = trie_changes.mem_trie_changes.as_ref().unwrap(); + let new_state_root = apply_memtrie_changes(&mut mem_tries, &mem_changes, block_height); + let mut new_chunk_extra = ChunkExtra::clone(&chunk_extra); + *new_chunk_extra.state_root_mut() = new_state_root; + + let mut chain_store_update = ChainStoreUpdate::new(&mut self.chain_store); + chain_store_update.save_chunk_extra(block_hash, &new_shard_uid, new_chunk_extra); + chain_store_update.save_state_transition_data( + *block_hash, + new_shard_uid.shard_id(), + Some(partial_storage), + CryptoHash::default(), + ); + chain_store_update.commit()?; + + let mut store_update = self.chain_store.store().store_update(); + tries.apply_insertions(&trie_changes, new_shard_uid, &mut store_update); + store_update.commit()?; + } + + Ok(()) + } + #[tracing::instrument(level = "debug", target = "chain", "postprocess_block_only", skip_all)] fn postprocess_block_only( &mut self, @@ -1867,6 +1941,7 @@ impl Chain { should_save_state_transition_data, )?; chain_update.commit()?; + Ok(new_head) } @@ -1936,20 +2011,13 @@ impl Chain { true, ); let care_about_shard_this_or_next_epoch = care_about_shard || will_care_about_shard; + let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, &epoch_id).unwrap(); if care_about_shard_this_or_next_epoch { - let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, &epoch_id).unwrap(); shards_cares_this_or_next_epoch.push(shard_uid); } - // Update flat storage head to be the last final block. 
Note that this update happens - // in a separate db transaction from the update from block processing. This is intentional - // because flat_storage need to be locked during the update of flat head, otherwise - // flat_storage is in an inconsistent state that could be accessed by the other - // apply chunks processes. This means, the flat head is not always the same as - // the last final block on chain, which is OK, because in the flat storage implementation - // we don't assume that. - let need_flat_storage_update = if is_caught_up { - // If we already caught up this epoch, then flat storage exists for both shards which we already track + let need_storage_update = if is_caught_up { + // If we already caught up this epoch, then storage exists for both shards which we already track // and shards which will be tracked in next epoch, so we can update them. care_about_shard_this_or_next_epoch } else { @@ -1957,9 +2025,19 @@ impl Chain { // during catchup of this block. care_about_shard }; - tracing::debug!(target: "chain", shard_id, need_flat_storage_update, "Updating flat storage"); - - if need_flat_storage_update { + tracing::debug!(target: "chain", shard_id, need_storage_update, "Updating storage"); + + if need_storage_update { + // TODO(#12019): consider adding to catchup flow. + self.process_instant_resharding_storage_update(&block, shard_uid)?; + + // Update flat storage head to be the last final block. Note that this update happens + // in a separate db transaction from the update from block processing. This is intentional + // because flat_storage need to be locked during the update of flat head, otherwise + // flat_storage is in an inconsistent state that could be accessed by the other + // apply chunks processes. This means, the flat head is not always the same as + // the last final block on chain, which is OK, because in the flat storage implementation + // we don't assume that. 
self.update_flat_storage_and_memtrie(&block, shard_id)?; } } diff --git a/core/store/src/trie/mem/mod.rs b/core/store/src/trie/mem/mod.rs index 73146e3a145..c2fc879cc18 100644 --- a/core/store/src/trie/mem/mod.rs +++ b/core/store/src/trie/mem/mod.rs @@ -18,6 +18,7 @@ pub mod lookup; pub mod metrics; pub mod node; mod parallel_loader; +pub mod resharding; pub mod updating; /// Check this, because in the code we conveniently assume usize is 8 bytes. diff --git a/core/store/src/trie/mem/resharding.rs b/core/store/src/trie/mem/resharding.rs new file mode 100644 index 00000000000..b20654d7c92 --- /dev/null +++ b/core/store/src/trie/mem/resharding.rs @@ -0,0 +1,328 @@ +use crate::trie::{MemTrieChanges, TrieRefcountDeltaMap}; +use crate::{NibbleSlice, TrieChanges}; + +use super::flexible_data::children::ChildrenView; +use super::flexible_data::value::ValueView; +use super::node::MemTrieNodeView; +use super::updating::{ + MemTrieUpdate, OldOrUpdatedNodeId, UpdatedMemTrieNode, UpdatedMemTrieNodeId, +}; +use super::{arena::STArenaMemory, node::MemTrieNodePtr}; +use itertools::Itertools; +use near_primitives::challenge::PartialState; +use near_primitives::hash::CryptoHash; +use near_primitives::types::AccountId; +use std::collections::HashMap; +use std::ops::Range; +use std::sync::Arc; + +/// Whether to retain left or right part of trie after shard split. +pub enum RetainMode { + Left, + Right, +} + +/// Decision on the subtree exploration. +#[derive(Debug)] +enum DeleteDecision { + /// Remove the whole subtree. + DeleteAll, + /// Descend into all child subtrees. + Descend, + /// Skip subtree, it is not impacted by deletion. + Skip, +} + +/// Tracks changes to the trie caused by the deletion. +struct UpdatesTracker { + #[allow(unused)] + node_accesses: HashMap>, + ordered_nodes: Vec, + updated_nodes: Vec>, + /// On-disk changes applied to reference counts of nodes. In fact, these + /// are only increments. 
+ refcount_changes: TrieRefcountDeltaMap, +} + +impl UpdatesTracker { + pub fn new() -> Self { + Self { + node_accesses: HashMap::new(), + ordered_nodes: Vec::new(), + updated_nodes: Vec::new(), + refcount_changes: TrieRefcountDeltaMap::new(), + } + } +} + +impl<'a> MemTrieUpdate<'a> { + /// Cut the trie, separating entries by the boundary account. + /// Leaves the left or right part of the trie, depending on the retain mode. + /// + /// Returns the changes to be applied to in-memory trie and the proof of + /// the cut operation. + pub fn cut( + &'a mut self, + _boundary_account: AccountId, + _retain_mode: RetainMode, + ) -> (TrieChanges, PartialState) { + // TODO(#12074): generate intervals in nibbles. + + self.delete_multi_range(&[]) + } + + /// Deletes keys belonging to any of the ranges given in `intervals` from + /// the trie. + /// + /// Returns changes to be applied to in-memory trie and proof of the + /// removal operation. + fn delete_multi_range( + &'a mut self, + intervals: &[Range>], + ) -> (TrieChanges, PartialState) { + let intervals_nibbles = intervals + .iter() + .map(|range| { + NibbleSlice::new(&range.start).iter().collect_vec() + ..NibbleSlice::new(&range.end).iter().collect_vec() + }) + .collect_vec(); + let mut updates_tracker = UpdatesTracker::new(); + let root = self.get_root().unwrap(); + // TODO(#12074): consider handling the case when no changes are made. + let _ = + delete_multi_range_recursive(root, vec![], &intervals_nibbles, &mut updates_tracker); + + let UpdatesTracker { ordered_nodes, updated_nodes, refcount_changes, .. 
} = updates_tracker; + let nodes_hashes_and_serialized = + self.compute_hashes_and_serialized_nodes(&ordered_nodes, &updated_nodes); + let node_ids_with_hashes = nodes_hashes_and_serialized + .iter() + .map(|(node_id, hash, _)| (*node_id, *hash)) + .collect(); + let memtrie_changes = MemTrieChanges { node_ids_with_hashes, updated_nodes }; + + let (trie_insertions, _) = TrieRefcountDeltaMap::into_changes(refcount_changes); + let trie_changes = TrieChanges { + // TODO(#12074): all the default fields are not used, consider + // using simpler struct. + old_root: CryptoHash::default(), + new_root: CryptoHash::default(), + insertions: trie_insertions, + deletions: Vec::default(), + mem_trie_changes: Some(memtrie_changes), + }; + + // TODO(#12074): restore proof as well. + (trie_changes, PartialState::default()) + } +} + +/// Recursive implementation of the algorithm of deleting keys belonging to +/// any of the ranges given in `intervals` from the trie. +/// +/// `root` is the root of subtree being explored. +/// `key_nibbles` is the key corresponding to `root`. +/// `intervals_nibbles` is the list of ranges to be deleted. +/// `updates_tracker` track changes to the trie caused by the deletion. +/// +/// Returns id of the node after deletion applied. 
+fn delete_multi_range_recursive<'a>( + root: MemTrieNodePtr<'a, STArenaMemory>, + key_nibbles: Vec, + intervals_nibbles: &[Range>], + updates_tracker: &mut UpdatesTracker, +) -> Option { + let decision = delete_decision(&key_nibbles, intervals_nibbles); + match decision { + DeleteDecision::Skip => return Some(OldOrUpdatedNodeId::Old(root.id())), + DeleteDecision::DeleteAll => { + return None; + } + DeleteDecision::Descend => {} + } + + let node_view = root.view(); + + let mut resolve_branch = |children: &ChildrenView<'a, STArenaMemory>, + mut value: Option<&ValueView>| { + let mut new_children = [None; 16]; + let mut changed = false; + + if intervals_nibbles.iter().any(|interval| interval.contains(&key_nibbles)) { + value = None; + changed = true; + } + + for i in 0..16 { + if let Some(child) = children.get(i) { + let child_key_nibbles = [key_nibbles.clone(), vec![i as u8]].concat(); + let new_child = delete_multi_range_recursive( + child, + child_key_nibbles, + intervals_nibbles, + updates_tracker, + ); + match new_child { + Some(OldOrUpdatedNodeId::Old(id)) => { + new_children[i] = Some(OldOrUpdatedNodeId::Old(id)); + } + Some(OldOrUpdatedNodeId::Updated(id)) => { + changed = true; + new_children[i] = Some(OldOrUpdatedNodeId::Updated(id)); + } + None => { + changed = true; + new_children[i] = None; + } + } + } + } + + if changed { + let new_node = UpdatedMemTrieNode::Branch { + children: Box::new(new_children), + value: value.map(|v| v.to_flat_value()), + }; + // TODO(#12074): squash the branch if needed. + // TODO(#12074): return None if needed. + Some(OldOrUpdatedNodeId::Updated(add_node(updates_tracker, new_node))) + } else { + Some(OldOrUpdatedNodeId::Old(root.id())) + } + }; + + match node_view { + MemTrieNodeView::Leaf { extension, .. 
} => { + let extension = NibbleSlice::from_encoded(extension).0; + let full_key_nibbles = [key_nibbles, extension.iter().collect_vec()].concat(); + if intervals_nibbles.iter().any(|interval| interval.contains(&full_key_nibbles)) { + None + } else { + Some(OldOrUpdatedNodeId::Old(root.id())) + } + } + MemTrieNodeView::Branch { children, .. } => resolve_branch(&children, None), + MemTrieNodeView::BranchWithValue { children, value, .. } => { + resolve_branch(&children, Some(&value)) + } + MemTrieNodeView::Extension { extension, child, .. } => { + let extension_nibbles = NibbleSlice::from_encoded(extension).0.iter().collect_vec(); + let child_key = [key_nibbles, extension_nibbles].concat(); + let new_child = + delete_multi_range_recursive(child, child_key, intervals_nibbles, updates_tracker); + + match new_child { + None => None, + Some(OldOrUpdatedNodeId::Old(id)) => Some(OldOrUpdatedNodeId::Old(id)), + Some(OldOrUpdatedNodeId::Updated(id)) => { + let new_node = UpdatedMemTrieNode::Extension { + extension: extension.to_vec().into_boxed_slice(), + child: OldOrUpdatedNodeId::Updated(id), + }; + Some(OldOrUpdatedNodeId::Updated(add_node(updates_tracker, new_node))) + } + } + } + } +} + +fn add_node( + updates_tracker: &mut UpdatesTracker, + node: UpdatedMemTrieNode, +) -> UpdatedMemTrieNodeId { + debug_assert!(node != UpdatedMemTrieNode::Empty); + let id = updates_tracker.ordered_nodes.len(); + updates_tracker.ordered_nodes.push(id); + updates_tracker.updated_nodes.push(Some(node)); + // TODO(#12074): apply remaining changes to `updates_tracker`. + + id +} + +/// Based on the key and the intervals, makes decision on the subtree exploration. 
+fn delete_decision(key: &[u8], intervals: &[Range>]) -> DeleteDecision { + let mut should_descend = false; + for interval in intervals { + if key < interval.start.as_slice() { + if interval.start.starts_with(key) { + should_descend = true; + } else { + // Skip + } + } else { + if key >= interval.end.as_slice() { + // Skip + } else if interval.end.starts_with(key) { + should_descend = true; + } else { + return DeleteDecision::DeleteAll; + } + } + } + + if should_descend { + DeleteDecision::Descend + } else { + DeleteDecision::Skip + } +} + +// TODO(#12074): tests for +// - multiple removal ranges +// - no-op removal +// - removing keys one-by-one gives the same result as range removal +// - `cut` API +// - all results of squashing branch +// - checking not accessing not-inlined nodes +// - proof correctness +// - (maybe) removal of the whole trie +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use itertools::Itertools; + use near_primitives::{shard_layout::ShardUId, types::StateRoot}; + + use crate::{ + trie::{ + mem::{iter::MemTrieIterator, updating::apply_memtrie_changes, MemTries}, + trie_storage::TrieMemoryPartialStorage, + }, + Trie, + }; + + #[test] + /// Applies single range removal to the trie and checks the result. 
+ fn test_delete_multi_range() { + let initial_entries = vec![ + (b"alice".to_vec(), vec![1]), + (b"bob".to_vec(), vec![2]), + (b"charlie".to_vec(), vec![3]), + (b"david".to_vec(), vec![4]), + ]; + let removal_range = b"bob".to_vec()..b"collin".to_vec(); + let removal_result = vec![(b"alice".to_vec(), vec![1]), (b"david".to_vec(), vec![4])]; + + let mut memtries = MemTries::new(ShardUId::single_shard()); + let empty_state_root = StateRoot::default(); + let mut update = memtries.update(empty_state_root, false).unwrap(); + for (key, value) in initial_entries { + update.insert(&key, value); + } + let memtrie_changes = update.to_mem_trie_changes_only(); + let state_root = apply_memtrie_changes(&mut memtries, &memtrie_changes, 0); + + let mut update = memtries.update(state_root, false).unwrap(); + let (mut trie_changes, _) = update.delete_multi_range(&[removal_range]); + let memtrie_changes = trie_changes.mem_trie_changes.take().unwrap(); + let new_state_root = apply_memtrie_changes(&mut memtries, &memtrie_changes, 1); + + let state_root_ptr = memtries.get_root(&new_state_root).unwrap(); + let trie = Trie::new(Arc::new(TrieMemoryPartialStorage::default()), new_state_root, None); + let entries = + MemTrieIterator::new(Some(state_root_ptr), &trie).map(|e| e.unwrap()).collect_vec(); + + assert_eq!(entries, removal_result); + } +} diff --git a/core/store/src/trie/mem/updating.rs b/core/store/src/trie/mem/updating.rs index 09ff28e3b99..b3fd10c20f9 100644 --- a/core/store/src/trie/mem/updating.rs +++ b/core/store/src/trie/mem/updating.rs @@ -1,7 +1,7 @@ use super::arena::STArenaMemory; use super::flexible_data::children::ChildrenView; use super::metrics::MEM_TRIE_NUM_NODES_CREATED_FROM_UPDATES; -use super::node::{InputMemTrieNode, MemTrieNodeId, MemTrieNodeView}; +use super::node::{InputMemTrieNode, MemTrieNodeId, MemTrieNodePtr, MemTrieNodeView}; use super::MemTries; use crate::trie::{Children, MemTrieChanges, TrieRefcountDeltaMap, TRIE_COSTS}; use crate::{NibbleSlice, 
RawTrieNode, RawTrieNodeWithSize, TrieChanges}; @@ -143,6 +143,10 @@ impl<'a> MemTrieUpdate<'a> { trie_update } + pub fn get_root(&self) -> Option> { + self.root.map(|id| id.as_ptr(self.arena)) + } + /// Internal function to take a node from the array of updated nodes, setting it /// to None. It is expected that place_node is then called to return the node to /// the same slot. @@ -691,13 +695,15 @@ impl<'a> MemTrieUpdate<'a> { } /// For each node in `ordered_nodes`, computes its hash and serialized data. - /// The `ordered_nodes` is expected to come from `post_order_traverse_updated_nodes`, - /// and updated_nodes are indexed by the node IDs in `ordered_nodes`. - fn compute_hashes_and_serialized_nodes( + /// `ordered_nodes` is expected to follow the post-order traversal of the + /// updated nodes. + /// `updated_nodes` must be indexed by the node IDs in `ordered_nodes`. + pub(crate) fn compute_hashes_and_serialized_nodes( + &self, ordered_nodes: &Vec, updated_nodes: &Vec>, - arena: &STArenaMemory, ) -> Vec<(UpdatedMemTrieNodeId, CryptoHash, Vec)> { + let arena = self.arena; let mut result = Vec::<(CryptoHash, u64, Vec)>::new(); for _ in 0..updated_nodes.len() { result.push((CryptoHash::default(), 0, Vec::new())); @@ -782,26 +788,22 @@ impl<'a> MemTrieUpdate<'a> { /// Converts the changes to memtrie changes. Also returns the list of new nodes inserted, /// in hash and serialized form. 
- fn to_mem_trie_changes_internal( - shard_uid: String, - arena: &STArenaMemory, - updated_nodes: Vec>, - ) -> (MemTrieChanges, Vec<(CryptoHash, Vec)>) { + fn to_mem_trie_changes_internal(self) -> (MemTrieChanges, Vec<(CryptoHash, Vec)>) { MEM_TRIE_NUM_NODES_CREATED_FROM_UPDATES - .with_label_values(&[&shard_uid]) - .inc_by(updated_nodes.len() as u64); + .with_label_values(&[&self.shard_uid]) + .inc_by(self.updated_nodes.len() as u64); let mut ordered_nodes = Vec::new(); - Self::post_order_traverse_updated_nodes(0, &updated_nodes, &mut ordered_nodes); + Self::post_order_traverse_updated_nodes(0, &self.updated_nodes, &mut ordered_nodes); let nodes_hashes_and_serialized = - Self::compute_hashes_and_serialized_nodes(&ordered_nodes, &updated_nodes, arena); + self.compute_hashes_and_serialized_nodes(&ordered_nodes, &self.updated_nodes); let node_ids_with_hashes = nodes_hashes_and_serialized .iter() .map(|(node_id, hash, _)| (*node_id, *hash)) .collect(); ( - MemTrieChanges { node_ids_with_hashes, updated_nodes }, + MemTrieChanges { node_ids_with_hashes, updated_nodes: self.updated_nodes }, nodes_hashes_and_serialized .into_iter() .map(|(_, hash, serialized)| (hash, serialized)) @@ -811,19 +813,19 @@ impl<'a> MemTrieUpdate<'a> { /// Converts the updates to memtrie changes only. pub fn to_mem_trie_changes_only(self) -> MemTrieChanges { - let Self { arena, updated_nodes, shard_uid, .. } = self; - let (mem_trie_changes, _) = - Self::to_mem_trie_changes_internal(shard_uid, arena, updated_nodes); + let (mem_trie_changes, _) = self.to_mem_trie_changes_internal(); mem_trie_changes } /// Converts the updates to trie changes as well as memtrie changes. 
- pub(crate) fn to_trie_changes(self) -> (TrieChanges, TrieAccesses) { - let Self { root, arena, shard_uid, tracked_trie_changes, updated_nodes } = self; - let TrieChangesTracker { mut refcount_changes, accesses } = - tracked_trie_changes.expect("Cannot to_trie_changes for memtrie changes only"); - let (mem_trie_changes, hashes_and_serialized) = - Self::to_mem_trie_changes_internal(shard_uid, arena, updated_nodes); + pub(crate) fn to_trie_changes(mut self) -> (TrieChanges, TrieAccesses) { + let old_root = + self.root.map(|root| root.as_ptr(self.arena).view().node_hash()).unwrap_or_default(); + let TrieChangesTracker { mut refcount_changes, accesses } = self + .tracked_trie_changes + .take() + .expect("Cannot to_trie_changes for memtrie changes only"); + let (mem_trie_changes, hashes_and_serialized) = self.to_mem_trie_changes_internal(); // We've accounted for the dereferenced nodes, as well as value addition/subtractions. // The only thing left is to increment refcount for all new nodes. @@ -834,9 +836,7 @@ impl<'a> MemTrieUpdate<'a> { ( TrieChanges { - old_root: root - .map(|root| root.as_ptr(arena).view().node_hash()) - .unwrap_or_default(), + old_root, new_root: mem_trie_changes .node_ids_with_hashes .last() diff --git a/core/store/src/trie/mod.rs b/core/store/src/trie/mod.rs index a38cc0f407b..12679530174 100644 --- a/core/store/src/trie/mod.rs +++ b/core/store/src/trie/mod.rs @@ -519,8 +519,13 @@ impl TrieRefcountDeltaMap { } } +/// Changes to be applied to in-memory trie. +/// Result is the new state root attached to existing persistent trie structure. #[derive(Default, Clone, PartialEq, Eq, Debug)] pub struct MemTrieChanges { + /// Node ids with hashes of updated nodes. + /// Should be in the post-order traversal of the updated nodes. + /// It implies that the root node is the last one in the list. 
node_ids_with_hashes: Vec<(UpdatedMemTrieNodeId, CryptoHash)>, updated_nodes: Vec>, } From 8b560cfb5d682ff8e91e9a2fd38eb4243079562e Mon Sep 17 00:00:00 2001 From: Longarithm Date: Thu, 19 Sep 2024 01:13:02 +0400 Subject: [PATCH 2/8] feedback --- chain/chain/src/chain.rs | 24 ++- core/store/src/trie/mem/resharding.rs | 238 ++++++++++++++------------ core/store/src/trie/mem/updating.rs | 6 +- 3 files changed, 149 insertions(+), 119 deletions(-) diff --git a/chain/chain/src/chain.rs b/chain/chain/src/chain.rs index 93275259e15..3ee23f367a3 100644 --- a/chain/chain/src/chain.rs +++ b/chain/chain/src/chain.rs @@ -1853,7 +1853,7 @@ impl Chain { /// memtries for new shards to be able to process them in the next epoch. /// Note this doesn't complete resharding, proper memtries are to be /// created later. - fn process_instant_resharding_storage_update( + fn process_memtrie_resharding_storage_update( &mut self, block: &Block, shard_uid: ShardUId, @@ -1882,8 +1882,15 @@ impl Chain { let Some(mem_tries) = tries.get_mem_tries(shard_uid) else { // TODO(#12019): what if node doesn't have memtrie? just pause // processing? - return Ok(()); + error!( + "Memtrie not loaded. Cannot process memtrie resharding storage + update for block {:?}, shard {:?}", + block_hash, shard_uid + ); + return Err(Error::Other("Memtrie not loaded".to_string())); }; + + // TODO(#12019): take proper boundary account. let boundary_account = AccountId::from_str("boundary.near").unwrap(); // TODO(#12019): leave only tracked shards. 
@@ -1895,15 +1902,18 @@ impl Chain { let mut mem_trie_update = mem_tries.update(*chunk_extra.state_root(), false)?; let (trie_changes, partial_state) = - mem_trie_update.cut(boundary_account.clone(), retain_mode); + mem_trie_update.retain_split_shard(boundary_account.clone(), retain_mode); let partial_storage = PartialStorage { nodes: partial_state }; let mem_changes = trie_changes.mem_trie_changes.as_ref().unwrap(); let new_state_root = mem_tries.apply_memtrie_changes(block_height, mem_changes); - let mut new_chunk_extra = ChunkExtra::clone(&chunk_extra); - *new_chunk_extra.state_root_mut() = new_state_root; + // TODO(#12019): set all fields of `ChunkExtra`. Consider stronger + // typing. Clarify where it should happen when `State` and + // `FlatState` update is implemented. + let mut child_chunk_extra = ChunkExtra::clone(&chunk_extra); + *child_chunk_extra.state_root_mut() = new_state_root; let mut chain_store_update = ChainStoreUpdate::new(&mut self.chain_store); - chain_store_update.save_chunk_extra(block_hash, &new_shard_uid, new_chunk_extra); + chain_store_update.save_chunk_extra(block_hash, &new_shard_uid, child_chunk_extra); chain_store_update.save_state_transition_data( *block_hash, new_shard_uid.shard_id(), @@ -2028,7 +2038,7 @@ impl Chain { if need_storage_update { // TODO(#12019): consider adding to catchup flow. - self.process_instant_resharding_storage_update(&block, shard_uid)?; + self.process_memtrie_resharding_storage_update(&block, shard_uid)?; // Update flat storage head to be the last final block. Note that this update happens // in a separate db transaction from the update from block processing. 
This is intentional diff --git a/core/store/src/trie/mem/resharding.rs b/core/store/src/trie/mem/resharding.rs index 9dd041513a6..d434100b703 100644 --- a/core/store/src/trie/mem/resharding.rs +++ b/core/store/src/trie/mem/resharding.rs @@ -3,7 +3,7 @@ use crate::{NibbleSlice, TrieChanges}; use super::flexible_data::children::ChildrenView; use super::flexible_data::value::ValueView; -use super::node::MemTrieNodeView; +use super::node::{MemTrieNodeId, MemTrieNodeView}; use super::updating::{ MemTrieUpdate, OldOrUpdatedNodeId, UpdatedMemTrieNode, UpdatedMemTrieNodeId, }; @@ -24,21 +24,23 @@ pub enum RetainMode { /// Decision on the subtree exploration. #[derive(Debug)] -enum DeleteDecision { - /// Remove the whole subtree. - DeleteAll, +enum RetainDecision { + /// Retain the whole subtree. + RetainAll, + /// The whole subtree is not retained. + NoRetain, /// Descend into all child subtrees. Descend, - /// Skip subtree, it is not impacted by deletion. - Skip, } -/// Tracks changes to the trie caused by the deletion. +/// Tracks changes to the trie caused by the retain. struct UpdatesTracker { + /// Accessed node hashes and their serializations. Used for proof + /// generation. #[allow(unused)] node_accesses: HashMap>, - ordered_nodes: Vec, - updated_nodes: Vec>, + /// All new nodes to be created. + updated_nodes: Vec, /// On-disk changes applied to reference counts of nodes. In fact, these /// are only increments. refcount_changes: TrieRefcountDeltaMap, @@ -48,35 +50,43 @@ impl UpdatesTracker { pub fn new() -> Self { Self { node_accesses: HashMap::new(), - ordered_nodes: Vec::new(), updated_nodes: Vec::new(), refcount_changes: TrieRefcountDeltaMap::new(), } } + + pub fn add_node(&mut self, node: UpdatedMemTrieNode) -> UpdatedMemTrieNodeId { + let id = self.updated_nodes.len(); + debug_assert!(node != UpdatedMemTrieNode::Empty); + self.updated_nodes.push(node); + // TODO(#12074): apply remaining changes to `updates_tracker`. 
+ + id + } } impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { - /// Cut the trie, separating entries by the boundary account. + /// Splits the trie, separating entries by the boundary account. /// Leaves the left or right part of the trie, depending on the retain mode. /// /// Returns the changes to be applied to in-memory trie and the proof of - /// the cut operation. - pub fn cut( + /// the split operation. + pub fn retain_split_shard( &'a mut self, _boundary_account: AccountId, _retain_mode: RetainMode, ) -> (TrieChanges, PartialState) { // TODO(#12074): generate intervals in nibbles. - self.delete_multi_range(&[]) + self.retain_multi_range(&[]) } - /// Deletes keys belonging to any of the ranges given in `intervals` from + /// Retains keys belonging to any of the ranges given in `intervals` from /// the trie. /// /// Returns changes to be applied to in-memory trie and proof of the - /// removal operation. - fn delete_multi_range( + /// retain operation. + fn retain_multi_range( &'a mut self, intervals: &[Range>], ) -> (TrieChanges, PartialState) { @@ -91,12 +101,16 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { let root = self.get_root().unwrap(); // TODO(#12074): consider handling the case when no changes are made. let _ = - delete_multi_range_recursive(root, vec![], &intervals_nibbles, &mut updates_tracker); + retain_multi_range_recursive(root, vec![], &intervals_nibbles, &mut updates_tracker); - let UpdatesTracker { ordered_nodes, updated_nodes, refcount_changes, .. } = updates_tracker; - let nodes_hashes_and_serialized = + let UpdatesTracker { updated_nodes, refcount_changes, .. } = updates_tracker; + // TODO(#12074): the next method requires more generic node structure. + // Consider simplifying the interface. 
+ let ordered_nodes = (0..updated_nodes.len()).collect_vec(); + let updated_nodes = updated_nodes.into_iter().map(Some).collect_vec(); + let hashes_and_serialized_nodes = self.compute_hashes_and_serialized_nodes(&ordered_nodes, &updated_nodes); - let node_ids_with_hashes = nodes_hashes_and_serialized + let node_ids_with_hashes = hashes_and_serialized_nodes .iter() .map(|(node_id, hash, _)| (*node_id, *hash)) .collect(); @@ -118,98 +132,60 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { } } -/// Recursive implementation of the algorithm of deleting keys belonging to +/// Recursive implementation of the algorithm of retaining keys belonging to /// any of the ranges given in `intervals` from the trie. /// /// `root` is the root of subtree being explored. /// `key_nibbles` is the key corresponding to `root`. -/// `intervals_nibbles` is the list of ranges to be deleted. -/// `updates_tracker` track changes to the trie caused by the deletion. +/// `intervals_nibbles` is the list of ranges to be retained. +/// `updates_tracker` track changes to the trie caused by the retain. /// -/// Returns id of the node after deletion applied. -fn delete_multi_range_recursive<'a, M: ArenaMemory>( +/// Returns id of the node after retain applied. 
+fn retain_multi_range_recursive<'a, M: ArenaMemory>( root: MemTrieNodePtr<'a, M>, key_nibbles: Vec, intervals_nibbles: &[Range>], updates_tracker: &mut UpdatesTracker, ) -> Option { - let decision = delete_decision(&key_nibbles, intervals_nibbles); + let decision = retain_decision(&key_nibbles, intervals_nibbles); match decision { - DeleteDecision::Skip => return Some(OldOrUpdatedNodeId::Old(root.id())), - DeleteDecision::DeleteAll => { - return None; - } - DeleteDecision::Descend => {} + RetainDecision::RetainAll => return Some(OldOrUpdatedNodeId::Old(root.id())), + RetainDecision::NoRetain => return None, + RetainDecision::Descend => {} } let node_view = root.view(); - let mut resolve_branch = |children: &ChildrenView<'a, M>, mut value: Option<&ValueView>| { - let mut new_children = [None; 16]; - let mut changed = false; - - if intervals_nibbles.iter().any(|interval| interval.contains(&key_nibbles)) { - value = None; - changed = true; - } - - for i in 0..16 { - if let Some(child) = children.get(i) { - let child_key_nibbles = [key_nibbles.clone(), vec![i as u8]].concat(); - let new_child = delete_multi_range_recursive( - child, - child_key_nibbles, - intervals_nibbles, - updates_tracker, - ); - match new_child { - Some(OldOrUpdatedNodeId::Old(id)) => { - new_children[i] = Some(OldOrUpdatedNodeId::Old(id)); - } - Some(OldOrUpdatedNodeId::Updated(id)) => { - changed = true; - new_children[i] = Some(OldOrUpdatedNodeId::Updated(id)); - } - None => { - changed = true; - new_children[i] = None; - } - } - } - } - - if changed { - let new_node = UpdatedMemTrieNode::Branch { - children: Box::new(new_children), - value: value.map(|v| v.to_flat_value()), - }; - // TODO(#12074): squash the branch if needed. - // TODO(#12074): return None if needed. 
- Some(OldOrUpdatedNodeId::Updated(add_node(updates_tracker, new_node))) - } else { - Some(OldOrUpdatedNodeId::Old(root.id())) - } + let mut retain_in_branch = |children: &ChildrenView<'a, M>, value: Option<&ValueView>| { + retain_multi_range_in_branch( + root.id(), + children, + value, + key_nibbles.clone(), + intervals_nibbles, + updates_tracker, + ) }; match node_view { MemTrieNodeView::Leaf { extension, .. } => { let extension = NibbleSlice::from_encoded(extension).0; let full_key_nibbles = [key_nibbles, extension.iter().collect_vec()].concat(); - if intervals_nibbles.iter().any(|interval| interval.contains(&full_key_nibbles)) { + if !intervals_nibbles.iter().any(|interval| interval.contains(&full_key_nibbles)) { None } else { Some(OldOrUpdatedNodeId::Old(root.id())) } } - MemTrieNodeView::Branch { children, .. } => resolve_branch(&children, None), + MemTrieNodeView::Branch { children, .. } => retain_in_branch(&children, None), MemTrieNodeView::BranchWithValue { children, value, .. } => { - resolve_branch(&children, Some(&value)) + retain_in_branch(&children, Some(&value)) } MemTrieNodeView::Extension { extension, child, .. } => { let extension_nibbles = NibbleSlice::from_encoded(extension).0.iter().collect_vec(); let child_key = [key_nibbles, extension_nibbles].concat(); let new_child = - delete_multi_range_recursive(child, child_key, intervals_nibbles, updates_tracker); + retain_multi_range_recursive(child, child_key, intervals_nibbles, updates_tracker); match new_child { None => None, @@ -219,63 +195,107 @@ fn delete_multi_range_recursive<'a, M: ArenaMemory>( extension: extension.to_vec().into_boxed_slice(), child: OldOrUpdatedNodeId::Updated(id), }; - Some(OldOrUpdatedNodeId::Updated(add_node(updates_tracker, new_node))) + Some(OldOrUpdatedNodeId::Updated(updates_tracker.add_node(new_node))) } } } } } -fn add_node( +/// Helper function for `retain_multi_range_recursive` when subtree is rooted +/// at a branch. 
+fn retain_multi_range_in_branch<'a, M: ArenaMemory>( + root_id: MemTrieNodeId, + children: &ChildrenView<'a, M>, + mut value: Option<&ValueView>, + key_nibbles: Vec, + intervals_nibbles: &[Range>], updates_tracker: &mut UpdatesTracker, - node: UpdatedMemTrieNode, -) -> UpdatedMemTrieNodeId { - debug_assert!(node != UpdatedMemTrieNode::Empty); - let id = updates_tracker.ordered_nodes.len(); - updates_tracker.ordered_nodes.push(id); - updates_tracker.updated_nodes.push(Some(node)); - // TODO(#12074): apply remaining changes to `updates_tracker`. - - id +) -> Option { + let mut new_children = [None; 16]; + let mut changed = false; + + if !intervals_nibbles.iter().any(|interval| interval.contains(&key_nibbles)) { + value = None; + changed = true; + } + + for i in 0..16 { + let Some(child) = children.get(i) else { + continue; + }; + + let child_key_nibbles = [key_nibbles.clone(), vec![i as u8]].concat(); + let new_child = retain_multi_range_recursive( + child, + child_key_nibbles, + intervals_nibbles, + updates_tracker, + ); + match new_child { + Some(OldOrUpdatedNodeId::Old(id)) => { + new_children[i] = Some(OldOrUpdatedNodeId::Old(id)); + } + Some(OldOrUpdatedNodeId::Updated(id)) => { + changed = true; + new_children[i] = Some(OldOrUpdatedNodeId::Updated(id)); + } + None => { + changed = true; + new_children[i] = None; + } + } + } + + if changed { + let new_node = UpdatedMemTrieNode::Branch { + children: Box::new(new_children), + value: value.map(|v| v.to_flat_value()), + }; + // TODO(#12074): squash the branch if needed. + // TODO(#12074): return None if needed. + Some(OldOrUpdatedNodeId::Updated(updates_tracker.add_node(new_node))) + } else { + Some(OldOrUpdatedNodeId::Old(root_id)) + } } /// Based on the key and the intervals, makes decision on the subtree exploration. 
-fn delete_decision(key: &[u8], intervals: &[Range>]) -> DeleteDecision { +fn retain_decision(key: &[u8], intervals: &[Range>]) -> RetainDecision { let mut should_descend = false; for interval in intervals { if key < interval.start.as_slice() { if interval.start.starts_with(key) { should_descend = true; } else { - // Skip + // No retain for this interval. } } else { if key >= interval.end.as_slice() { - // Skip + // No retain for this interval. } else if interval.end.starts_with(key) { should_descend = true; } else { - return DeleteDecision::DeleteAll; + return RetainDecision::RetainAll; } } } if should_descend { - DeleteDecision::Descend + RetainDecision::Descend } else { - DeleteDecision::Skip + RetainDecision::NoRetain } } // TODO(#12074): tests for -// - multiple removal ranges -// - no-op removal -// - removing keys one-by-one gives the same result as range removal -// - `cut` API +// - multiple retain ranges +// - removing keys one-by-one gives the same result as corresponding range retain +// - `retain_split_shard` API // - all results of squashing branch // - checking not accessing not-inlined nodes // - proof correctness -// - (maybe) removal of the whole trie +// - (maybe) retain result is empty or complete tree #[cfg(test)] mod tests { use std::sync::Arc; @@ -292,16 +312,16 @@ mod tests { }; #[test] - /// Applies single range removal to the trie and checks the result. - fn test_delete_multi_range() { + /// Applies single range retain to the trie and checks the result. 
+ fn test_retain_single_range() { let initial_entries = vec![ (b"alice".to_vec(), vec![1]), (b"bob".to_vec(), vec![2]), (b"charlie".to_vec(), vec![3]), (b"david".to_vec(), vec![4]), ]; - let removal_range = b"bob".to_vec()..b"collin".to_vec(); - let removal_result = vec![(b"alice".to_vec(), vec![1]), (b"david".to_vec(), vec![4])]; + let retain_range = b"amy".to_vec()..b"david".to_vec(); + let retain_result = vec![(b"bob".to_vec(), vec![2]), (b"charlie".to_vec(), vec![3])]; let mut memtries = MemTries::new(ShardUId::single_shard()); let empty_state_root = StateRoot::default(); @@ -313,7 +333,7 @@ mod tests { let state_root = memtries.apply_memtrie_changes(0, &memtrie_changes); let mut update = memtries.update(state_root, false).unwrap(); - let (mut trie_changes, _) = update.delete_multi_range(&[removal_range]); + let (mut trie_changes, _) = update.retain_multi_range(&[retain_range]); let memtrie_changes = trie_changes.mem_trie_changes.take().unwrap(); let new_state_root = memtries.apply_memtrie_changes(1, &memtrie_changes); @@ -322,6 +342,6 @@ mod tests { let entries = MemTrieIterator::new(Some(state_root_ptr), &trie).map(|e| e.unwrap()).collect_vec(); - assert_eq!(entries, removal_result); + assert_eq!(entries, retain_result); } } diff --git a/core/store/src/trie/mem/updating.rs b/core/store/src/trie/mem/updating.rs index 8195dbd3acc..aef3613e39e 100644 --- a/core/store/src/trie/mem/updating.rs +++ b/core/store/src/trie/mem/updating.rs @@ -793,16 +793,16 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { let mut ordered_nodes = Vec::new(); Self::post_order_traverse_updated_nodes(0, &self.updated_nodes, &mut ordered_nodes); - let nodes_hashes_and_serialized = + let hashes_and_serialized_nodes = self.compute_hashes_and_serialized_nodes(&ordered_nodes, &self.updated_nodes); - let node_ids_with_hashes = nodes_hashes_and_serialized + let node_ids_with_hashes = hashes_and_serialized_nodes .iter() .map(|(node_id, hash, _)| (*node_id, *hash)) .collect(); ( MemTrieChanges { 
node_ids_with_hashes, updated_nodes: self.updated_nodes }, - nodes_hashes_and_serialized + hashes_and_serialized_nodes .into_iter() .map(|(_, hash, serialized)| (hash, serialized)) .collect(), From 9c493f0fe1e0304c9ff468c017c2ae16102cb5f8 Mon Sep 17 00:00:00 2001 From: Longarithm Date: Thu, 19 Sep 2024 16:49:10 +0400 Subject: [PATCH 3/8] feedback 2 --- chain/chain/src/chain.rs | 5 +- core/store/src/trie/mem/resharding.rs | 316 +++++++++----------------- core/store/src/trie/mem/updating.rs | 8 +- 3 files changed, 114 insertions(+), 215 deletions(-) diff --git a/chain/chain/src/chain.rs b/chain/chain/src/chain.rs index 3ee23f367a3..77675c617b0 100644 --- a/chain/chain/src/chain.rs +++ b/chain/chain/src/chain.rs @@ -1899,10 +1899,11 @@ impl Chain { (children_shard_uids[1], RetainMode::Right), ] { let mut mem_tries = mem_tries.write().unwrap(); - let mut mem_trie_update = mem_tries.update(*chunk_extra.state_root(), false)?; + let mem_trie_update = mem_tries.update(*chunk_extra.state_root(), true)?; - let (trie_changes, partial_state) = + let (trie_changes, _) = mem_trie_update.retain_split_shard(boundary_account.clone(), retain_mode); + let partial_state = PartialState::default(); let partial_storage = PartialStorage { nodes: partial_state }; let mem_changes = trie_changes.mem_trie_changes.as_ref().unwrap(); let new_state_root = mem_tries.apply_memtrie_changes(block_height, mem_changes); diff --git a/core/store/src/trie/mem/resharding.rs b/core/store/src/trie/mem/resharding.rs index d434100b703..272aecc9a96 100644 --- a/core/store/src/trie/mem/resharding.rs +++ b/core/store/src/trie/mem/resharding.rs @@ -1,20 +1,10 @@ -use crate::trie::{MemTrieChanges, TrieRefcountDeltaMap}; use crate::{NibbleSlice, TrieChanges}; -use super::flexible_data::children::ChildrenView; -use super::flexible_data::value::ValueView; -use super::node::{MemTrieNodeId, MemTrieNodeView}; -use super::updating::{ - MemTrieUpdate, OldOrUpdatedNodeId, UpdatedMemTrieNode, UpdatedMemTrieNodeId, -}; -use 
super::{arena::ArenaMemory, node::MemTrieNodePtr}; +use super::arena::ArenaMemory; +use super::updating::{MemTrieUpdate, OldOrUpdatedNodeId, TrieAccesses, UpdatedMemTrieNode}; use itertools::Itertools; -use near_primitives::challenge::PartialState; -use near_primitives::hash::CryptoHash; use near_primitives::types::AccountId; -use std::collections::HashMap; use std::ops::Range; -use std::sync::Arc; /// Whether to retain left or right part of trie after shard split. pub enum RetainMode { @@ -33,38 +23,6 @@ enum RetainDecision { Descend, } -/// Tracks changes to the trie caused by the retain. -struct UpdatesTracker { - /// Accessed node hashes and their serializations. Used for proof - /// generation. - #[allow(unused)] - node_accesses: HashMap>, - /// All new nodes to be created. - updated_nodes: Vec, - /// On-disk changes applied to reference counts of nodes. In fact, these - /// are only increments. - refcount_changes: TrieRefcountDeltaMap, -} - -impl UpdatesTracker { - pub fn new() -> Self { - Self { - node_accesses: HashMap::new(), - updated_nodes: Vec::new(), - refcount_changes: TrieRefcountDeltaMap::new(), - } - } - - pub fn add_node(&mut self, node: UpdatedMemTrieNode) -> UpdatedMemTrieNodeId { - let id = self.updated_nodes.len(); - debug_assert!(node != UpdatedMemTrieNode::Empty); - self.updated_nodes.push(node); - // TODO(#12074): apply remaining changes to `updates_tracker`. - - id - } -} - impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { /// Splits the trie, separating entries by the boundary account. /// Leaves the left or right part of the trie, depending on the retain mode. @@ -72,10 +30,10 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { /// Returns the changes to be applied to in-memory trie and the proof of /// the split operation. 
pub fn retain_split_shard( - &'a mut self, + self, _boundary_account: AccountId, _retain_mode: RetainMode, - ) -> (TrieChanges, PartialState) { + ) -> (TrieChanges, TrieAccesses) { // TODO(#12074): generate intervals in nibbles. self.retain_multi_range(&[]) @@ -86,10 +44,8 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { /// /// Returns changes to be applied to in-memory trie and proof of the /// retain operation. - fn retain_multi_range( - &'a mut self, - intervals: &[Range>], - ) -> (TrieChanges, PartialState) { + fn retain_multi_range(mut self, intervals: &[Range>]) -> (TrieChanges, TrieAccesses) { + debug_assert!(intervals.iter().all(|range| range.start < range.end)); let intervals_nibbles = intervals .iter() .map(|range| { @@ -97,188 +53,130 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { ..NibbleSlice::new(&range.end).iter().collect_vec() }) .collect_vec(); - let mut updates_tracker = UpdatesTracker::new(); - let root = self.get_root().unwrap(); + // let root = self.get_root().unwrap(); // TODO(#12074): consider handling the case when no changes are made. - let _ = - retain_multi_range_recursive(root, vec![], &intervals_nibbles, &mut updates_tracker); - - let UpdatesTracker { updated_nodes, refcount_changes, .. } = updates_tracker; - // TODO(#12074): the next method requires more generic node structure. - // Consider simplifying the interface. 
- let ordered_nodes = (0..updated_nodes.len()).collect_vec(); - let updated_nodes = updated_nodes.into_iter().map(Some).collect_vec(); - let hashes_and_serialized_nodes = - self.compute_hashes_and_serialized_nodes(&ordered_nodes, &updated_nodes); - let node_ids_with_hashes = hashes_and_serialized_nodes - .iter() - .map(|(node_id, hash, _)| (*node_id, *hash)) - .collect(); - let memtrie_changes = MemTrieChanges { node_ids_with_hashes, updated_nodes }; - - let (trie_insertions, _) = TrieRefcountDeltaMap::into_changes(refcount_changes); - let trie_changes = TrieChanges { - // TODO(#12074): all the default fields are not used, consider - // using simpler struct. - old_root: CryptoHash::default(), - new_root: CryptoHash::default(), - insertions: trie_insertions, - deletions: Vec::default(), - mem_trie_changes: Some(memtrie_changes), - }; - // TODO(#12074): restore proof as well. - (trie_changes, PartialState::default()) - } -} - -/// Recursive implementation of the algorithm of retaining keys belonging to -/// any of the ranges given in `intervals` from the trie. -/// -/// `root` is the root of subtree being explored. -/// `key_nibbles` is the key corresponding to `root`. -/// `intervals_nibbles` is the list of ranges to be retained. -/// `updates_tracker` track changes to the trie caused by the retain. -/// -/// Returns id of the node after retain applied. 
-fn retain_multi_range_recursive<'a, M: ArenaMemory>( - root: MemTrieNodePtr<'a, M>, - key_nibbles: Vec, - intervals_nibbles: &[Range>], - updates_tracker: &mut UpdatesTracker, -) -> Option { - let decision = retain_decision(&key_nibbles, intervals_nibbles); - match decision { - RetainDecision::RetainAll => return Some(OldOrUpdatedNodeId::Old(root.id())), - RetainDecision::NoRetain => return None, - RetainDecision::Descend => {} + self.retain_multi_range_recursive(0, vec![], &intervals_nibbles); + self.to_trie_changes() } - let node_view = root.view(); - - let mut retain_in_branch = |children: &ChildrenView<'a, M>, value: Option<&ValueView>| { - retain_multi_range_in_branch( - root.id(), - children, - value, - key_nibbles.clone(), - intervals_nibbles, - updates_tracker, - ) - }; - - match node_view { - MemTrieNodeView::Leaf { extension, .. } => { - let extension = NibbleSlice::from_encoded(extension).0; - let full_key_nibbles = [key_nibbles, extension.iter().collect_vec()].concat(); - if !intervals_nibbles.iter().any(|interval| interval.contains(&full_key_nibbles)) { - None - } else { - Some(OldOrUpdatedNodeId::Old(root.id())) + /// Recursive implementation of the algorithm of retaining keys belonging to + /// any of the ranges given in `intervals` from the trie. + /// + /// `node_id` is the root of subtree being explored. + /// `key_nibbles` is the key corresponding to `root`. + /// `intervals_nibbles` is the list of ranges to be retained. + /// + /// Returns id of the node after retain applied. + fn retain_multi_range_recursive( + &mut self, + node_id: usize, + key_nibbles: Vec, + intervals_nibbles: &[Range>], + ) { + let decision = retain_decision(&key_nibbles, intervals_nibbles); + match decision { + RetainDecision::RetainAll => return, + RetainDecision::NoRetain => { + let _ = self.take_node(node_id); + self.place_node(node_id, UpdatedMemTrieNode::Empty); + return; } + RetainDecision::Descend => {} } - MemTrieNodeView::Branch { children, .. 
} => retain_in_branch(&children, None), - MemTrieNodeView::BranchWithValue { children, value, .. } => { - retain_in_branch(&children, Some(&value)) - } - MemTrieNodeView::Extension { extension, child, .. } => { - let extension_nibbles = NibbleSlice::from_encoded(extension).0.iter().collect_vec(); - let child_key = [key_nibbles, extension_nibbles].concat(); - let new_child = - retain_multi_range_recursive(child, child_key, intervals_nibbles, updates_tracker); - match new_child { - None => None, - Some(OldOrUpdatedNodeId::Old(id)) => Some(OldOrUpdatedNodeId::Old(id)), - Some(OldOrUpdatedNodeId::Updated(id)) => { - let new_node = UpdatedMemTrieNode::Extension { - extension: extension.to_vec().into_boxed_slice(), - child: OldOrUpdatedNodeId::Updated(id), - }; - Some(OldOrUpdatedNodeId::Updated(updates_tracker.add_node(new_node))) + let node = self.take_node(node_id); + match node { + UpdatedMemTrieNode::Empty => { + // Nowhere to descend. + self.place_node(node_id, UpdatedMemTrieNode::Empty); + return; + } + UpdatedMemTrieNode::Leaf { extension, value } => { + let full_key_nibbles = + [key_nibbles, NibbleSlice::from_encoded(&extension).0.iter().collect_vec()] + .concat(); + if !intervals_nibbles.iter().any(|interval| interval.contains(&full_key_nibbles)) { + self.place_node(node_id, UpdatedMemTrieNode::Empty); + } else { + self.place_node(node_id, UpdatedMemTrieNode::Leaf { extension, value }); } } - } - } -} + UpdatedMemTrieNode::Branch { mut children, mut value } => { + if !intervals_nibbles.iter().any(|interval| interval.contains(&key_nibbles)) { + value = None; + } -/// Helper function for `retain_multi_range_recursive` when subtree is rooted -/// at a branch. 
-fn retain_multi_range_in_branch<'a, M: ArenaMemory>( - root_id: MemTrieNodeId, - children: &ChildrenView<'a, M>, - mut value: Option<&ValueView>, - key_nibbles: Vec, - intervals_nibbles: &[Range>], - updates_tracker: &mut UpdatesTracker, -) -> Option { - let mut new_children = [None; 16]; - let mut changed = false; + for i in 0..16 { + let child = &mut children[i]; + let Some(old_child_id) = child.take() else { + continue; + }; - if !intervals_nibbles.iter().any(|interval| interval.contains(&key_nibbles)) { - value = None; - changed = true; - } + let new_child_id = self.ensure_updated(old_child_id); + let child_key_nibbles = [key_nibbles.clone(), vec![i as u8]].concat(); + self.retain_multi_range_recursive( + new_child_id, + child_key_nibbles, + intervals_nibbles, + ); + if self.updated_nodes[new_child_id] == Some(UpdatedMemTrieNode::Empty) { + *child = None; + } else { + *child = Some(OldOrUpdatedNodeId::Updated(new_child_id)); + } + } - for i in 0..16 { - let Some(child) = children.get(i) else { - continue; - }; + // TODO(#12074): squash the branch if needed. Consider reusing + // `squash_nodes`. 
- let child_key_nibbles = [key_nibbles.clone(), vec![i as u8]].concat(); - let new_child = retain_multi_range_recursive( - child, - child_key_nibbles, - intervals_nibbles, - updates_tracker, - ); - match new_child { - Some(OldOrUpdatedNodeId::Old(id)) => { - new_children[i] = Some(OldOrUpdatedNodeId::Old(id)); - } - Some(OldOrUpdatedNodeId::Updated(id)) => { - changed = true; - new_children[i] = Some(OldOrUpdatedNodeId::Updated(id)); + self.place_node(node_id, UpdatedMemTrieNode::Branch { children, value }); } - None => { - changed = true; - new_children[i] = None; + UpdatedMemTrieNode::Extension { extension, child } => { + let new_child_id = self.ensure_updated(child); + let extension_nibbles = + NibbleSlice::from_encoded(&extension).0.iter().collect_vec(); + let child_key = [key_nibbles, extension_nibbles].concat(); + self.retain_multi_range_recursive(new_child_id, child_key, intervals_nibbles); + + if self.updated_nodes[new_child_id] == Some(UpdatedMemTrieNode::Empty) { + self.place_node(node_id, UpdatedMemTrieNode::Empty); + } else { + self.place_node( + node_id, + UpdatedMemTrieNode::Extension { + extension, + child: OldOrUpdatedNodeId::Updated(new_child_id), + }, + ); + } } } } - - if changed { - let new_node = UpdatedMemTrieNode::Branch { - children: Box::new(new_children), - value: value.map(|v| v.to_flat_value()), - }; - // TODO(#12074): squash the branch if needed. - // TODO(#12074): return None if needed. - Some(OldOrUpdatedNodeId::Updated(updates_tracker.add_node(new_node))) - } else { - Some(OldOrUpdatedNodeId::Old(root_id)) - } } /// Based on the key and the intervals, makes decision on the subtree exploration. fn retain_decision(key: &[u8], intervals: &[Range>]) -> RetainDecision { let mut should_descend = false; for interval in intervals { - if key < interval.start.as_slice() { - if interval.start.starts_with(key) { - should_descend = true; - } else { - // No retain for this interval. 
- } - } else { - if key >= interval.end.as_slice() { - // No retain for this interval. - } else if interval.end.starts_with(key) { - should_descend = true; - } else { - return RetainDecision::RetainAll; - } + // If key can be extended to be equal to start or end of the interval, + // its subtree may have keys inside the interval. At the same time, + // it can be extended with bytes which would fall outside the interval. + // For example, if key is "a" and interval is "ab".."cd", subtree may + // contain "aa" which must be excluded. + if interval.start.starts_with(key) || interval.end.starts_with(key) { + should_descend = true; + continue; } + + // If key is not a prefix of boundaries and falls inside the interval, + // one can show that all the keys in the subtree are also inside the + // interval. + if interval.start.as_slice() <= key && key < interval.end.as_slice() { + return RetainDecision::RetainAll; + } + + // Otherwise, all the keys in the subtree are outside the interval. } if should_descend { @@ -332,7 +230,7 @@ mod tests { let memtrie_changes = update.to_mem_trie_changes_only(); let state_root = memtries.apply_memtrie_changes(0, &memtrie_changes); - let mut update = memtries.update(state_root, false).unwrap(); + let update = memtries.update(state_root, true).unwrap(); let (mut trie_changes, _) = update.retain_multi_range(&[retain_range]); let memtrie_changes = trie_changes.mem_trie_changes.take().unwrap(); let new_state_root = memtries.apply_memtrie_changes(1, &memtrie_changes); diff --git a/core/store/src/trie/mem/updating.rs b/core/store/src/trie/mem/updating.rs index aef3613e39e..24b31d35b5e 100644 --- a/core/store/src/trie/mem/updating.rs +++ b/core/store/src/trie/mem/updating.rs @@ -43,7 +43,7 @@ pub enum UpdatedMemTrieNode { } /// Keeps values and internal nodes accessed on updating memtrie. -pub(crate) struct TrieAccesses { +pub struct TrieAccesses { /// Hashes and encoded trie nodes. 
pub nodes: HashMap>, /// Hashes of accessed values - because values themselves are not @@ -148,12 +148,12 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { /// Internal function to take a node from the array of updated nodes, setting it /// to None. It is expected that place_node is then called to return the node to /// the same slot. - fn take_node(&mut self, index: UpdatedMemTrieNodeId) -> UpdatedMemTrieNode { + pub(crate) fn take_node(&mut self, index: UpdatedMemTrieNodeId) -> UpdatedMemTrieNode { self.updated_nodes.get_mut(index).unwrap().take().expect("Node taken twice") } /// Does the opposite of take_node; returns the node to the specified ID. - fn place_node(&mut self, index: UpdatedMemTrieNodeId, node: UpdatedMemTrieNode) { + pub(crate) fn place_node(&mut self, index: UpdatedMemTrieNodeId, node: UpdatedMemTrieNode) { assert!(self.updated_nodes[index].is_none(), "Node placed twice"); self.updated_nodes[index] = Some(node); } @@ -195,7 +195,7 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { } /// If the ID was old, converts it to an updated one. 
- fn ensure_updated(&mut self, node: OldOrUpdatedNodeId) -> UpdatedMemTrieNodeId { + pub(crate) fn ensure_updated(&mut self, node: OldOrUpdatedNodeId) -> UpdatedMemTrieNodeId { match node { OldOrUpdatedNodeId::Old(node_id) => self.convert_existing_to_updated(Some(node_id)), OldOrUpdatedNodeId::Updated(node_id) => node_id, From b32afb84ad175044037b06f678debd9038727dd2 Mon Sep 17 00:00:00 2001 From: Longarithm Date: Thu, 19 Sep 2024 17:49:17 +0400 Subject: [PATCH 4/8] minor --- core/store/src/trie/mem/resharding.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/store/src/trie/mem/resharding.rs b/core/store/src/trie/mem/resharding.rs index 272aecc9a96..53c660eceaa 100644 --- a/core/store/src/trie/mem/resharding.rs +++ b/core/store/src/trie/mem/resharding.rs @@ -28,7 +28,8 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { /// Leaves the left or right part of the trie, depending on the retain mode. /// /// Returns the changes to be applied to in-memory trie and the proof of - /// the split operation. + /// the split operation. Doesn't modify the trie itself; it is the caller's + /// responsibility to apply the changes. 
pub fn retain_split_shard( self, _boundary_account: AccountId, @@ -188,12 +189,12 @@ fn retain_decision(key: &[u8], intervals: &[Range>]) -> RetainDecision { // TODO(#12074): tests for // - multiple retain ranges +// - result is empty, or no changes are made // - removing keys one-by-one gives the same result as corresponding range retain // - `retain_split_shard` API // - all results of squashing branch // - checking not accessing not-inlined nodes // - proof correctness -// - (maybe) retain result is empty or complete tree #[cfg(test)] mod tests { use std::sync::Arc; From f4ae5458ae875d417f39477056cd3c1b034ad87d Mon Sep 17 00:00:00 2001 From: Longarithm Date: Thu, 19 Sep 2024 23:25:36 +0400 Subject: [PATCH 5/8] nit --- chain/chain/src/chain.rs | 1 - core/store/src/trie/mem/resharding.rs | 14 +++++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/chain/chain/src/chain.rs b/chain/chain/src/chain.rs index 77675c617b0..1be35c74e95 100644 --- a/chain/chain/src/chain.rs +++ b/chain/chain/src/chain.rs @@ -1951,7 +1951,6 @@ impl Chain { should_save_state_transition_data, )?; chain_update.commit()?; - Ok(new_head) } diff --git a/core/store/src/trie/mem/resharding.rs b/core/store/src/trie/mem/resharding.rs index 53c660eceaa..a4b8040b853 100644 --- a/core/store/src/trie/mem/resharding.rs +++ b/core/store/src/trie/mem/resharding.rs @@ -18,7 +18,7 @@ enum RetainDecision { /// Retain the whole subtree. RetainAll, /// The whole subtree is not retained. - NoRetain, + DiscardAll, /// Descend into all child subtrees. 
Descend, } @@ -78,12 +78,14 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { let decision = retain_decision(&key_nibbles, intervals_nibbles); match decision { RetainDecision::RetainAll => return, - RetainDecision::NoRetain => { + RetainDecision::DiscardAll => { let _ = self.take_node(node_id); self.place_node(node_id, UpdatedMemTrieNode::Empty); return; } - RetainDecision::Descend => {} + RetainDecision::Descend => { + // We need to descend into all children. The logic follows below. + } } let node = self.take_node(node_id); @@ -163,8 +165,10 @@ fn retain_decision(key: &[u8], intervals: &[Range>]) -> RetainDecision { // If key can be extended to be equal to start or end of the interval, // its subtree may have keys inside the interval. At the same time, // it can be extended with bytes which would fall outside the interval. + // // For example, if key is "a" and interval is "ab".."cd", subtree may - // contain "aa" which must be excluded. + // contain both "aa" which must be excluded and "ac" which must be + // retained. 
if interval.start.starts_with(key) || interval.end.starts_with(key) { should_descend = true; continue; @@ -183,7 +187,7 @@ fn retain_decision(key: &[u8], intervals: &[Range>]) -> RetainDecision { if should_descend { RetainDecision::Descend } else { - RetainDecision::NoRetain + RetainDecision::DiscardAll } } From d84fdb62b5bcbd940e963a7ba17150b8dbc828f6 Mon Sep 17 00:00:00 2001 From: Longarithm Date: Thu, 19 Sep 2024 23:28:07 +0400 Subject: [PATCH 6/8] nit --- core/store/src/trie/mem/updating.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/core/store/src/trie/mem/updating.rs b/core/store/src/trie/mem/updating.rs index 24b31d35b5e..817f986e6aa 100644 --- a/core/store/src/trie/mem/updating.rs +++ b/core/store/src/trie/mem/updating.rs @@ -141,10 +141,6 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { trie_update } - pub fn get_root(&self) -> Option> { - self.root.map(|id| id.as_ptr(self.memory)) - } - /// Internal function to take a node from the array of updated nodes, setting it /// to None. It is expected that place_node is then called to return the node to /// the same slot. 
From 9c840a1c47331ad36a28801ed1d9974df18edf3d Mon Sep 17 00:00:00 2001 From: Longarithm Date: Thu, 19 Sep 2024 23:35:00 +0400 Subject: [PATCH 7/8] nit --- core/store/src/trie/mem/updating.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/store/src/trie/mem/updating.rs b/core/store/src/trie/mem/updating.rs index 817f986e6aa..adfe6504bdc 100644 --- a/core/store/src/trie/mem/updating.rs +++ b/core/store/src/trie/mem/updating.rs @@ -1,7 +1,7 @@ use super::arena::{ArenaMemory, ArenaMut}; use super::flexible_data::children::ChildrenView; use super::metrics::MEM_TRIE_NUM_NODES_CREATED_FROM_UPDATES; -use super::node::{InputMemTrieNode, MemTrieNodeId, MemTrieNodePtr, MemTrieNodeView}; +use super::node::{InputMemTrieNode, MemTrieNodeId, MemTrieNodeView}; use crate::trie::{Children, MemTrieChanges, TrieRefcountDeltaMap, TRIE_COSTS}; use crate::{NibbleSlice, RawTrieNode, RawTrieNodeWithSize, TrieChanges}; use near_primitives::hash::{hash, CryptoHash}; From 8f0cc630da643aec588c9485b2b2833519c7ff2a Mon Sep 17 00:00:00 2001 From: Longarithm Date: Fri, 20 Sep 2024 14:04:01 +0400 Subject: [PATCH 8/8] feedback 3 --- core/store/src/trie/mem/resharding.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/core/store/src/trie/mem/resharding.rs b/core/store/src/trie/mem/resharding.rs index a4b8040b853..281e828abba 100644 --- a/core/store/src/trie/mem/resharding.rs +++ b/core/store/src/trie/mem/resharding.rs @@ -54,7 +54,7 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { ..NibbleSlice::new(&range.end).iter().collect_vec() }) .collect_vec(); - // let root = self.get_root().unwrap(); + // TODO(#12074): consider handling the case when no changes are made. // TODO(#12074): restore proof as well. 
self.retain_multi_range_recursive(0, vec![], &intervals_nibbles); @@ -62,13 +62,12 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { } /// Recursive implementation of the algorithm of retaining keys belonging to - /// any of the ranges given in `intervals` from the trie. + /// any of the ranges given in `intervals` from the trie. All changes are + /// applied in `updated_nodes`. /// /// `node_id` is the root of subtree being explored. /// `key_nibbles` is the key corresponding to `root`. /// `intervals_nibbles` is the list of ranges to be retained. - /// - /// Returns id of the node after retain applied. fn retain_multi_range_recursive( &mut self, node_id: usize, @@ -110,8 +109,7 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { value = None; } - for i in 0..16 { - let child = &mut children[i]; + for (i, child) in children.iter_mut().enumerate() { let Some(old_child_id) = child.take() else { continue; };