Skip to content

Commit

Permalink
store: replace TrieIterator::seek by seek_prefix (#7585)
Browse files Browse the repository at this point in the history
Replace the TrieIterator::seek method with TrieIterator::seek_prefix
which, as the name suggests, limits the traversal to keys which match
the given prefix.  That is, where seek used to return all keys no less
than the query, seek_prefix will now further limit that set to keys
which start with the query.

Issue: #2076
  • Loading branch information
blasrodri authored Sep 9, 2022
1 parent 947af6b commit a59462c
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 39 deletions.
70 changes: 53 additions & 17 deletions core/store/src/trie/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::{StorageError, Trie};
/// A single entry on the iterator's trail: a trie node together with the
/// traversal status the iterator has reached for that node.
struct Crumb {
node: TrieNodeWithSize,
status: CrumbStatus,
/// Set by `seek_nibble_slice` during a prefix seek. When true, `increment`
/// moves this crumb straight to `CrumbStatus::Exiting` instead of advancing
/// through the node's remaining states.
prefix_boundary: bool,
}

#[derive(Clone, Copy, Eq, PartialEq, Debug)]
Expand All @@ -20,6 +21,10 @@ pub(crate) enum CrumbStatus {

impl Crumb {
fn increment(&mut self) {
if self.prefix_boundary {
self.status = CrumbStatus::Exiting;
return;
}
self.status = match (&self.status, &self.node.node) {
(_, &TrieNode::Empty) => CrumbStatus::Exiting,
(&CrumbStatus::Entering, _) => CrumbStatus::At,
Expand Down Expand Up @@ -62,26 +67,44 @@ impl<'a> TrieIterator<'a> {
}

/// Position the iterator on the first element whose key starts with `key`.
pub fn seek<K: AsRef<[u8]>>(&mut self, key: K) -> Result<(), StorageError> {
self.seek_nibble_slice(NibbleSlice::new(key.as_ref())).map(drop)
pub fn seek_prefix<K: AsRef<[u8]>>(&mut self, key: K) -> Result<(), StorageError> {
self.seek_nibble_slice(NibbleSlice::new(key.as_ref()), true).map(drop)
}

/// Returns the hash of the last node
pub(crate) fn seek_nibble_slice(
&mut self,
mut key: NibbleSlice<'_>,
is_prefix_seek: bool,
) -> Result<CryptoHash, StorageError> {
self.trail.clear();
self.key_nibbles.clear();
// Checks if a key in an extension or leaf matches our search query.
//
// When doing prefix seek, this checks whether `key` is a prefix of
// `ext_key`. When doing regular range seek, this checks whether `key`
// is no greater than `ext_key`. If those conditions aren’t met, the
// node with `ext_key` should not match our query.
let check_ext_key = |key: &NibbleSlice, ext_key: &NibbleSlice| {
if is_prefix_seek {
ext_key.starts_with(key)
} else {
ext_key >= key
}
};

let mut hash = self.trie.root;
let mut prev_prefix_boundary = &mut false;
loop {
*prev_prefix_boundary = is_prefix_seek;
self.descend_into_node(&hash)?;
let Crumb { status, node } = self.trail.last_mut().unwrap();
let Crumb { status, node, prefix_boundary } = self.trail.last_mut().unwrap();
prev_prefix_boundary = prefix_boundary;
match &node.node {
TrieNode::Empty => break,
TrieNode::Leaf(leaf_key, _) => {
let existing_key = NibbleSlice::from_encoded(leaf_key).0;
if existing_key < key {
if !check_ext_key(&key, &existing_key) {
self.key_nibbles.extend(existing_key.iter());
*status = CrumbStatus::Exiting;
}
Expand All @@ -98,6 +121,7 @@ impl<'a> TrieIterator<'a> {
hash = *child.unwrap_hash();
key = key.mid(1);
} else {
*prefix_boundary = is_prefix_seek;
break;
}
}
Expand All @@ -110,7 +134,7 @@ impl<'a> TrieIterator<'a> {
*status = CrumbStatus::At;
self.key_nibbles.extend(existing_key.iter());
} else {
if existing_key < key {
if !check_ext_key(&key, &existing_key) {
*status = CrumbStatus::Exiting;
self.key_nibbles.extend(existing_key.iter());
}
Expand All @@ -127,7 +151,7 @@ impl<'a> TrieIterator<'a> {
/// The node is stored as the last [`Crumb`] in the trail.
fn descend_into_node(&mut self, hash: &CryptoHash) -> Result<(), StorageError> {
let node = self.trie.retrieve_node(hash)?.1;
self.trail.push(Crumb { status: CrumbStatus::Entering, node });
self.trail.push(Crumb { status: CrumbStatus::Entering, node, prefix_boundary: false });
Ok(())
}

Expand Down Expand Up @@ -227,7 +251,7 @@ impl<'a> TrieIterator<'a> {
path_end: &[u8],
) -> Result<Vec<TrieItem>, StorageError> {
let path_begin_encoded = NibbleSlice::encode_nibbles(path_begin, false);
self.seek_nibble_slice(NibbleSlice::from_encoded(&path_begin_encoded).0)?;
self.seek_nibble_slice(NibbleSlice::from_encoded(&path_begin_encoded).0, false)?;

let mut trie_items = vec![];
for item in self {
Expand All @@ -250,7 +274,8 @@ impl<'a> TrieIterator<'a> {
path_end: &[u8],
) -> Result<Vec<TrieTraversalItem>, StorageError> {
let path_begin_encoded = NibbleSlice::encode_nibbles(path_begin, true);
let last_hash = self.seek_nibble_slice(NibbleSlice::from_encoded(&path_begin_encoded).0)?;
let last_hash =
self.seek_nibble_slice(NibbleSlice::from_encoded(&path_begin_encoded).0, false)?;
let mut prefix = Self::common_prefix(path_end, &self.key_nibbles);
if self.key_nibbles[prefix..] >= path_end[prefix..] {
return Ok(vec![]);
Expand Down Expand Up @@ -371,15 +396,15 @@ mod tests {
let result2: Vec<_> = map.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
assert_eq!(result1, result2);
}
test_seek(&trie, &map, &[]);
test_seek_prefix(&trie, &map, &[]);

let empty_vec = vec![];
let max_key = map.keys().max().unwrap_or(&empty_vec);
let min_key = map.keys().min().unwrap_or(&empty_vec);
test_get_trie_items(&trie, &map, &[], &[]);
test_get_trie_items(&trie, &map, min_key, max_key);
for (seek_key, _) in trie_changes.iter() {
test_seek(&trie, &map, seek_key);
test_seek_prefix(&trie, &map, seek_key);
test_get_trie_items(&trie, &map, min_key, seek_key);
test_get_trie_items(&trie, &map, seek_key, max_key);
}
Expand All @@ -388,7 +413,7 @@ mod tests {
let key_length = rng.gen_range(1, 8);
let seek_key: Vec<u8> =
(0..key_length).map(|_| *alphabet.choose(&mut rng).unwrap()).collect();
test_seek(&trie, &map, &seek_key);
test_seek_prefix(&trie, &map, &seek_key);

let seek_key2: Vec<u8> =
(0..key_length).map(|_| *alphabet.choose(&mut rng).unwrap()).collect();
Expand Down Expand Up @@ -422,13 +447,24 @@ mod tests {
assert_eq!(result1, result2);
}

fn test_seek(trie: &Trie, map: &BTreeMap<Vec<u8>, Vec<u8>>, seek_key: &[u8]) {
fn test_seek_prefix(trie: &Trie, map: &BTreeMap<Vec<u8>, Vec<u8>>, seek_key: &[u8]) {
let mut iterator = trie.iter().unwrap();
iterator.seek(&seek_key).unwrap();
let result1: Vec<_> = iterator.map(Result::unwrap).take(5).collect();
let result2: Vec<_> =
map.range(seek_key.to_vec()..).map(|(k, v)| (k.clone(), v.clone())).take(5).collect();
assert_eq!(result1, result2);
iterator.seek_prefix(&seek_key).unwrap();
let mut got = Vec::with_capacity(5);
for item in iterator {
let (key, value) = item.unwrap();
assert!(key.starts_with(seek_key), "‘{key:x?}’ does not start with ‘{seek_key:x?}’");
if got.len() < 5 {
got.push((key, value));
}
}
let want: Vec<_> = map
.range(seek_key.to_vec()..)
.map(|(k, v)| (k.clone(), v.clone()))
.take(5)
.filter(|(x, _)| x.starts_with(seek_key))
.collect();
assert_eq!(got, want);
}

#[test]
Expand Down
28 changes: 19 additions & 9 deletions core/store/src/trie/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ enum RawTrieNode {
/// memory_usage is serialized, stored, and contributes to hash
#[derive(Debug, Eq, PartialEq)]
struct RawTrieNodeWithSize {
node: RawTrieNode,
pub node: RawTrieNode,
memory_usage: u64,
}

Expand Down Expand Up @@ -1077,9 +1077,19 @@ mod tests {
}
assert_eq!(pairs, iter_pairs);

let assert_has_next = |want, other_iter: &mut TrieIterator| {
assert_eq!(Some(want), other_iter.next().map(|item| item.unwrap().0).as_deref());
};

let mut other_iter = trie.iter().unwrap();
other_iter.seek(b"r").unwrap();
assert_eq!(other_iter.next().unwrap().unwrap().0, b"x".to_vec());
other_iter.seek_prefix(b"r").unwrap();
assert_eq!(other_iter.next(), None);
other_iter.seek_prefix(b"x").unwrap();
assert_has_next(b"x", &mut other_iter);
assert_eq!(other_iter.next(), None);
other_iter.seek_prefix(b"y").unwrap();
assert_has_next(b"y", &mut other_iter);
assert_eq!(other_iter.next(), None);
}

#[test]
Expand Down Expand Up @@ -1131,13 +1141,13 @@ mod tests {
let root = test_populate_trie(&tries, &Trie::EMPTY_ROOT, ShardUId::single_shard(), changes);
let trie = tries.get_trie_for_shard(ShardUId::single_shard(), root.clone());
let mut iter = trie.iter().unwrap();
iter.seek(&vec![0, 116, 101, 115, 116, 44]).unwrap();
iter.seek_prefix(&[0, 116, 101, 115, 116, 44]).unwrap();
let mut pairs = vec![];
for pair in iter {
pairs.push(pair.unwrap().0);
}
assert_eq!(
pairs[..2],
pairs,
[
vec![
0, 116, 101, 115, 116, 44, 98, 97, 108, 97, 110, 99, 101, 115, 58, 98, 111, 98,
Expand Down Expand Up @@ -1219,7 +1229,7 @@ mod tests {
}

#[test]
fn test_iterator_seek() {
fn test_iterator_seek_prefix() {
let mut rng = rand::thread_rng();
for _test_run in 0..10 {
let tries = create_tries();
Expand All @@ -1237,7 +1247,7 @@ mod tests {
if let Some(value) = value {
let want = Some(Ok((key.clone(), value)));
let mut iterator = trie.iter().unwrap();
iterator.seek(&key).unwrap();
iterator.seek_prefix(&key).unwrap();
assert_eq!(want, iterator.next(), "key: {key:x?}");
}
}
Expand All @@ -1246,9 +1256,9 @@ mod tests {
let queries = gen_changes(&mut rng, 500).into_iter().map(|(key, _)| key);
for query in queries {
let mut iterator = trie.iter().unwrap();
iterator.seek(&query).unwrap();
iterator.seek_prefix(&query).unwrap();
if let Some(Ok((key, _))) = iterator.next() {
assert!(key >= query);
assert!(key.starts_with(&query), "‘{key:x?}’ does not start with ‘{query:x?}’");
}
}
}
Expand Down
18 changes: 10 additions & 8 deletions core/store/src/trie/state_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ impl Trie {
if part_id.idx + 1 != part_id.total {
let mut iterator = self.iter()?;
let path_end_encoded = NibbleSlice::encode_nibbles(&path_end, false);
iterator.seek_nibble_slice(NibbleSlice::from_encoded(&path_end_encoded[..]).0)?;
iterator
.seek_nibble_slice(NibbleSlice::from_encoded(&path_end_encoded[..]).0, false)?;
if let Some(item) = iterator.next() {
item?;
}
Expand Down Expand Up @@ -99,14 +100,14 @@ impl Trie {
}
TrieNode::Branch(children, _) => {
for child_index in 0..children.len() {
let (_, child) = match &children[child_index] {
let child = match &children[child_index] {
None => {
continue;
}
Some(NodeHandle::InMemory(_)) => {
unreachable!("only possible while mutating")
}
Some(NodeHandle::Hash(h)) => self.retrieve_node(h)?,
Some(NodeHandle::Hash(h)) => self.retrieve_node(h)?.1,
};
if *size_skipped + child.memory_usage <= size_start {
*size_skipped += child.memory_usage;
Expand Down Expand Up @@ -134,9 +135,9 @@ impl Trie {
Ok(false)
}
TrieNode::Extension(key, child_handle) => {
let (_, child) = match child_handle {
let child = match child_handle {
NodeHandle::InMemory(_) => unreachable!("only possible while mutating"),
NodeHandle::Hash(h) => self.retrieve_node(h)?,
NodeHandle::Hash(h) => self.retrieve_node(h)?.1,
};
let (slice, _is_leaf) = NibbleSlice::from_encoded(key);
key_nibbles.extend(slice.iter());
Expand Down Expand Up @@ -353,7 +354,7 @@ mod tests {
}
if i < 16 {
if let Some(NodeHandle::Hash(h)) = children[i].clone() {
let (_, child) = self.retrieve_node(&h)?;
let child = self.retrieve_node(&h)?.1;
stack.push((hash, node, CrumbStatus::AtChild(i + 1)));
stack.push((h, child, CrumbStatus::Entering));
} else {
Expand All @@ -377,7 +378,7 @@ mod tests {
unreachable!("only possible while mutating")
}
NodeHandle::Hash(h) => {
let (_, child) = self.retrieve_node(&h)?;
let child = self.retrieve_node(&h)?.1;
stack.push((hash, node, CrumbStatus::Exiting));
stack.push((h, child, CrumbStatus::Entering));
}
Expand All @@ -400,7 +401,8 @@ mod tests {

let mut iterator = self.iter()?;
let path_begin_encoded = NibbleSlice::encode_nibbles(&path_begin, false);
iterator.seek_nibble_slice(NibbleSlice::from_encoded(&path_begin_encoded[..]).0)?;
iterator
.seek_nibble_slice(NibbleSlice::from_encoded(&path_begin_encoded[..]).0, false)?;
loop {
match iterator.next() {
None => break,
Expand Down
3 changes: 2 additions & 1 deletion core/store/src/trie/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ impl<'a> TrieUpdateIterator<'a> {
}
None => None,
};
trie_iter.seek(&start_offset)?;
trie_iter.seek_prefix(&start_offset)?;
let committed_iter = state_update.committed.range(start_offset.clone()..).map(
|(raw_key, changes_with_trie_key)| {
(
Expand Down Expand Up @@ -262,6 +262,7 @@ impl<'a> Iterator for TrieUpdateIterator<'a> {

fn next(&mut self) -> Option<Self::Item> {
let stop_cond = |key: &[u8], prefix: &[u8], end_offset: &Option<Vec<u8>>| {
// TODO(mina86): Figure out if starts_with check is still necessary.
!key.starts_with(prefix) || end_offset.as_deref().map_or(false, |end| key >= end)
};
enum Ordering {
Expand Down
5 changes: 1 addition & 4 deletions runtime/runtime/src/state_viewer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,9 @@ impl TrieViewer {
let query = trie_key_parsers::get_raw_prefix_for_contract_data(account_id, prefix);
let acc_sep_len = query.len() - prefix.len();
let mut iter = state_update.trie().iter()?;
iter.seek(&query)?;
iter.seek_prefix(&query)?;
for item in iter {
let (key, value) = item?;
if !key.starts_with(query.as_ref()) {
break;
}
values.push(StateItem {
key: key[acc_sep_len..].to_vec(),
value: value,
Expand Down

0 comments on commit a59462c

Please sign in to comment.