From 8e9eb252f8d499459ab6374e04cac0a6f183f987 Mon Sep 17 00:00:00 2001 From: JustForFun88 Date: Sat, 25 Feb 2023 13:41:37 +0500 Subject: [PATCH] Initial implementation of `try_get_many` --- src/map.rs | 76 ++++++++++++- src/raw/array.rs | 270 +++++++++++++++++++++++++++++++++++++++++++++++ src/raw/mod.rs | 159 ++++++++++++++++++++++++++++ 3 files changed, 504 insertions(+), 1 deletion(-) create mode 100644 src/raw/array.rs diff --git a/src/map.rs b/src/map.rs index 57fba3046b..2dc7cf45de 100644 --- a/src/map.rs +++ b/src/map.rs @@ -1,4 +1,4 @@ -use crate::raw::{Allocator, Bucket, Global, RawDrain, RawIntoIter, RawIter, RawTable}; +use crate::raw::{Allocator, ArrayIter, Bucket, Global, RawDrain, RawIntoIter, RawIter, RawTable}; use crate::{Equivalent, TryReserveError}; use core::borrow::Borrow; use core::fmt::{self, Debug}; @@ -1668,6 +1668,40 @@ where .map(|res| res.map(|(k, v)| (&*k, v))) } + /// Attempts to get mutable references to `N` values in the map at once, with immutable + /// references to the corresponding keys. + /// + /// Returns an [`ArrayIter`] of length `N` with the results of each query. For soundness, + /// at most one mutable reference will be returned to any value. All duplicated keys will + /// be ignored. + /// + /// The order of elements in the returned iterator may not be the same as the order of + /// elements in lookup iterator `iter: &mut I`. + /// + /// Also, if `N` is less than the length of the iterator, the iterator will still be valid + /// and may continue to be used, in which case it will continue iterating from the element + /// remaining immediately after receiving `N` successful queries. 
+ pub fn try_get_many_key_value_mut<'a, Q, I, const N: usize>( + &mut self, + iter: &mut I, + ) -> ArrayIter<(&K, &mut V), N> + where + I: Iterator, + Q: ?Sized + Hash + Equivalent + 'a, + { + let hash_builder = &self.hash_builder; + + let mut iter = iter.map(|key| { + ( + make_hash::(hash_builder, key), + equivalent_key::(key), + ) + }); + self.table + .try_get_many_mut(&mut iter) + .convert(|(ref k, v)| (k, v)) + } + /// Attempts to get mutable references to `N` values in the map at once, with immutable /// references to the corresponding keys, without validating that the values are unique. /// @@ -1723,6 +1757,46 @@ where .map(|res| res.map(|(k, v)| (&*k, v))) } + /// Attempts to get mutable references to `N` values in the map at once, with immutable + /// references to the corresponding keys. + /// + /// Returns an [`ArrayIter`] of length `N` with the results of each query. + /// + /// The order of the elements in the returned iterator is the same as the order of the + /// elements in the search iterator `iter: &mut I`, except for elements that were not found. + /// + /// Also, if `N` is less than the length of the iterator, the iterator will still be valid + /// and may continue to be used, in which case it will continue iterating from the element + /// remaining immediately after receiving `N` successful queries. + /// + /// # Safety + /// + /// Calling this method is *[undefined behavior]* if the iterator contains overlapping + /// items that refer to the same `elements` in the table, even if the resulting + /// references to `elements` in the table are not used. 
+ /// + /// [undefined behavior]: https://doc.rust-lang.org/reference/behavior-considered-undefined.html + pub unsafe fn try_get_many_key_value_unchecked_mut<'a, Q, I, const N: usize>( + &mut self, + iter: &mut I, + ) -> ArrayIter<(&K, &mut V), N> + where + I: Iterator, + Q: ?Sized + Hash + Equivalent + 'a, + { + let hash_builder = &self.hash_builder; + + let mut iter = iter.map(|key| { + ( + make_hash::(hash_builder, key), + equivalent_key::(key), + ) + }); + self.table + .try_get_many_mut_unchecked(&mut iter) + .convert(|(ref k, v)| (k, v)) + } + fn get_many_mut_inner( &mut self, ks: [&Q; N], diff --git a/src/raw/array.rs b/src/raw/array.rs new file mode 100644 index 0000000000..5c627318cc --- /dev/null +++ b/src/raw/array.rs @@ -0,0 +1,270 @@ +use crate::scopeguard::guard; +use core::iter::FusedIterator; +use core::mem; +use core::mem::MaybeUninit; +use core::ops::Range; +use core::ptr; + +/// Builder for incremental initialization of [`ArrayIter`]. +/// +/// # Safety +/// +/// All write accesses to this structure are unsafe and must maintain a correct +/// count of `initialized` elements. +pub(crate) struct ArrayIterBuilder { + /// The array to be initialized. + array_mut: [MaybeUninit; N], + /// The number of items that have been initialized so far. + initialized: usize, +} + +impl ArrayIterBuilder { + /// Creates new [`ArrayIterBuilder`]. + #[inline] + pub(crate) fn new() -> Self { + ArrayIterBuilder { + // SAFETY: The `assume_init` is safe because the type we are claiming to have + // initialized here is a bunch of `MaybeUninit`s, which do not require initialization. + array_mut: unsafe { MaybeUninit::uninit().assume_init() }, + initialized: 0, + } + } + + /// Adds an item to the array and updates the initialized item counter. + /// + /// # Safety + /// + /// No more than `N` elements must be initialized. 
+ #[inline] + pub(crate) unsafe fn push_unchecked(&mut self, item: T) { + // SAFETY: If `initialized` was correct before and the caller does not + // invoke this method more than `N` times then writes will be in-bounds + // and slots will not be initialized more than once. + self.array_mut + .get_unchecked_mut(self.initialized) + .write(item); + self.initialized += 1; + } + + /// Returns `true` if all elements have been initialized + #[inline] + pub(crate) fn is_initialized(&self) -> bool { + self.initialized == N + } + + /// Builds [`ArrayIter`] from [`ArrayIterBuilder`]. + #[inline] + pub(crate) fn build(self) -> ArrayIter { + let initialized = 0..self.initialized; + // SAFETY: We provide the number of elements that are guaranteed to be initialized + unsafe { ArrayIter::new_unchecked(self.array_mut, initialized) } + } +} + +/// A by-value [array] iterator. +pub struct ArrayIter { + data: [MaybeUninit; N], + alive: Range, +} + +impl ArrayIter { + const DATA_NEEDS_DROP: bool = mem::needs_drop::(); + + /// Creates an iterator over the elements in a partially-initialized buffer. + /// + /// # Safety + /// + /// - The `buffer[initialized]` elements must all be initialized. + /// - The range must be canonical, with `initialized.start <= initialized.end`. + /// - The range must be in-bounds for the buffer, with `initialized.end <= N`. + /// (Like how indexing `[0][100..100]` fails despite the range being empty.) + /// + /// It's sound to have more elements initialized than mentioned, though that + /// will most likely result in them being leaked. + #[inline] + const unsafe fn new_unchecked(buffer: [MaybeUninit; N], initialized: Range) -> Self { + // SAFETY: one of our safety conditions is that the range is canonical. + Self { + data: buffer, + alive: initialized, + } + } + + /// Returns an immutable slice of all elements that have not been yielded yet. 
+ #[inline] + pub fn as_slice(&self) -> &[T] { + unsafe { + // SAFETY: We know that all elements within `alive` are properly initialized. + let slice = self.data.get_unchecked(self.alive.clone()); + // SAFETY: casting `slice` to a `*const [T]` is safe since the `slice` is initialized, + // and `MaybeUninit` is guaranteed to have the same layout as `T`. + // The pointer obtained is valid since it refers to memory owned by `slice` which is a + // reference and thus guaranteed to be valid for reads. + &*(slice as *const [MaybeUninit] as *const [T]) + } + } + + /// Returns a mutable slice of all elements that have not been yielded yet. + #[inline] + pub fn as_mut_slice(&mut self) -> &mut [T] { + unsafe { + // SAFETY: We know that all elements within `alive` are properly initialized. + let slice = self.data.get_unchecked_mut(self.alive.clone()); + // SAFETY: casting `slice` to a `*mut [T]` is safe since the `slice` is initialized, + // and `MaybeUninit` is guaranteed to have the same layout as `T`. + // The pointer obtained is valid since it refers to memory owned by `slice` which is a + // reference and thus guaranteed to be valid for reads and writes. + &mut *(slice as *mut [MaybeUninit] as *mut [T]) + } + } + + /// Returns an [`ArrayIter`] of the same size as `self`, with function `f` + /// applied to each element in order. + pub fn convert(self, mut f: F) -> ArrayIter + where + F: FnMut(T) -> U, + { + let mut builder = ArrayIterBuilder::::new(); + if ArrayIter::::DATA_NEEDS_DROP { + // Function may panic, in which case we need to make sure that we drop the elements + // that have already been produced. + let mut guard = guard(&mut builder, |self_| { + // SAFETY: + // 1. The `slice` will contain only initialized objects; + // 2. 
`MaybeUninit` is guaranteed to have the same size, alignment, and ABI as U + unsafe { + let slice: *mut [MaybeUninit] = + self_.array_mut.get_unchecked_mut(..self_.initialized); + ptr::drop_in_place(slice as *mut [U]); + } + }); + + for item in self { + // SAFETY: `self` length is equal to `builder/guard` length + unsafe { guard.push_unchecked(f(item)) } + } + mem::forget(guard); + } else { + for item in self { + // SAFETY: `self` length is equal to `builder` length + unsafe { builder.push_unchecked(f(item)) } + } + } + + builder.build() + } +} + +impl Drop for ArrayIter { + fn drop(&mut self) { + // SAFETY: This is safe: `as_mut_slice` returns exactly the sub-slice + // of elements that have not been moved out yet and that remain + // to be dropped. + unsafe { ptr::drop_in_place(self.as_mut_slice()) } + } +} + +impl Iterator for ArrayIter { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + // Get the next index from the front. + // + // Increasing `alive.start` by 1 maintains the invariant regarding + // `alive`. However, due to this change, for a short time, the alive + // zone is not `data[alive]` anymore, but `data[idx..alive.end]`. + // + // Avoid `Option::unwrap_or_else` because it bloats LLVM IR. + match self.alive.next() { + Some(idx) => { + // Read the element from the array. + // SAFETY: `idx` is an index into the former "alive" region of the + // array. Reading this element means that `data[idx]` is regarded as + // dead now (i.e. do not touch). As `idx` was the start of the + // alive-zone, the alive zone is now `data[alive]` again, restoring + // all invariants. 
+ Some(unsafe { self.data.get_unchecked(idx).assume_init_read() }) + } + None => None, + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } + + #[inline] + fn fold(mut self, init: Acc, mut fold: Fold) -> Acc + where + Fold: FnMut(Acc, Self::Item) -> Acc, + { + let data = &mut self.data; + self.alive.by_ref().fold(init, |acc, idx| { + // SAFETY: idx is obtained by folding over the `alive` range, which implies the + // value is currently considered alive but as the range is being consumed each value + // we read here will only be read once and then considered dead. + fold(acc, unsafe { data.get_unchecked(idx).assume_init_read() }) + }) + } + + #[inline] + fn count(self) -> usize { + self.len() + } + + #[inline] + fn last(mut self) -> Option { + self.next_back() + } +} + +impl DoubleEndedIterator for ArrayIter { + #[inline] + fn next_back(&mut self) -> Option { + // Get the next index from the back. + // + // Decreasing `alive.end` by 1 maintains the invariant regarding + // `alive`. However, due to this change, for a short time, the alive + // zone is not `data[alive]` anymore, but `data[alive.start..=idx]`. + // + // Avoid `Option::unwrap_or_else` because it bloats LLVM IR. + match self.alive.next_back() { + Some(idx) => { + // Read the element from the array. + // SAFETY: `idx` is an index into the former "alive" region of the + // array. Reading this element means that `data[idx]` is regarded as + // dead now (i.e. do not touch). As `idx` was the end of the + // alive-zone, the alive zone is now `data[alive]` again, restoring + // all invariants. 
+ Some(unsafe { self.data.get_unchecked(idx).assume_init_read() }) + } + None => None, + } + } + + #[inline] + fn rfold(mut self, init: Acc, mut rfold: Fold) -> Acc + where + Fold: FnMut(Acc, Self::Item) -> Acc, + { + let data = &mut self.data; + self.alive.by_ref().rfold(init, |acc, idx| { + // SAFETY: idx is obtained by folding over the `alive` range, which implies the + // value is currently considered alive but as the range is being consumed each value + // we read here will only be read once and then considered dead. + rfold(acc, unsafe { data.get_unchecked(idx).assume_init_read() }) + }) + } +} + +impl ExactSizeIterator for ArrayIter { + #[inline] + fn len(&self) -> usize { + self.alive.len() + } +} + +impl FusedIterator for ArrayIter {} diff --git a/src/raw/mod.rs b/src/raw/mod.rs index e86bd239a2..52e1b814e6 100644 --- a/src/raw/mod.rs +++ b/src/raw/mod.rs @@ -1,6 +1,7 @@ use crate::alloc::alloc::{handle_alloc_error, Layout}; use crate::scopeguard::{guard, ScopeGuard}; use crate::TryReserveError; +use core::cmp::Ordering; use core::iter::FusedIterator; use core::marker::PhantomData; use core::mem; @@ -40,6 +41,11 @@ mod bitmask; use self::bitmask::{BitMask, BitMaskIter}; use self::imp::Group; +mod array; + +pub use self::array::ArrayIter; +use self::array::ArrayIterBuilder; + // Branch prediction hint. This is currently only available on nightly but it // consistently improves performance by 10-15%. #[cfg(feature = "nightly")] @@ -289,6 +295,32 @@ pub struct Bucket { // never exposed in a public API. 
unsafe impl Send for Bucket {} +impl PartialEq for Bucket { + #[inline] + fn eq(&self, other: &Self) -> bool { + // Comparing `ptr: NonNull` directly since T can be ZST + self.ptr.as_ptr() == other.ptr.as_ptr() + } +} + +impl Eq for Bucket {} + +impl Ord for Bucket { + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + // Comparing `ptr: NonNull` directly since T can be ZST + self.ptr.as_ptr().cmp(&other.ptr.as_ptr()) + } +} + +impl PartialOrd for Bucket { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + // Comparing `ptr: NonNull` directly since T can be ZST + self.ptr.as_ptr().partial_cmp(&other.ptr.as_ptr()) + } +} + impl Clone for Bucket { #[inline] fn clone(&self) -> Self { @@ -1274,6 +1306,133 @@ impl RawTable { } } + /// Attempts to get mutable references to `N` entries in the table at once using + /// `hash` and equality function from iterator. + /// + /// Returns an [`ArrayIter`] of length `N` with the results of each query. + /// + /// At most one mutable reference will be returned to any entry. + /// Duplicate values will be skipped. + /// + /// The `iter` argument should be an iterator that return `hash` of the stored + /// `element` and closure for checking the equivalence of that `element`. + /// + /// The order of elements in the returned iterator may not be the same as the order of + /// elements in lookup iterator `iter: &mut I`. + /// + /// Also, if `N` is less than the length of the iterator, the iterator will still be valid + /// and may continue to be used, in which case it will continue iterating from the element + /// remaining immediately after receiving `N` successful queries. + pub(crate) fn try_get_many_mut<'a, I, F, const N: usize>( + &'a mut self, + iter: &mut I, + ) -> ArrayIter<&'a mut T, N> + where + I: Iterator, + F: FnMut(&T) -> bool, + { + if N == 0 { + // SAFETY: An empty array is always inhabited and has no validity invariants. 
+ return ArrayIterBuilder::<&mut T, N>::new().build(); + } + + let mut builder = ArrayIterBuilder::, N>::new(); + + for (hash, eq) in iter { + if let Some(bucket) = self.find(hash, eq) { + // SAFETY: + // 1. `N` is greater than 0, which was checked above + // 2. We break the loop immediately after initializing all elements + // (it's okay if we haven't initialized all the elements, but overflow + // must not be allowed) + unsafe { builder.push_unchecked(bucket) } + if builder.is_initialized() { + break; + } + } + } + + let mut array_iter = builder.build(); + array_iter.as_mut_slice().sort_unstable(); + + let mut out_builder = ArrayIterBuilder::<&mut T, N>::new(); + + if let Some(mut cur_ptr) = array_iter.next() { + // SAFETY: + // 1. `N` is greater than 0, which was checked above + // 2. One pointer is always unique. + // 3. We got all buckets from the `find` function, so they are valid. + unsafe { out_builder.push_unchecked(cur_ptr.as_mut()) }; + + for next_ptr in array_iter { + if cur_ptr != next_ptr { + // SAFETY: + // 1. The `array_iter` length is less than or equal to the `out_builder` length. + // 2. We have just verified that this is a unique pointer that does not + // repeat within the array. + // 3. We got all buckets from the `find` function, so they are valid. + unsafe { out_builder.push_unchecked(next_ptr.as_mut()) }; + cur_ptr = next_ptr; + } + } + } + out_builder.build() + } + + /// Attempts to get mutable references to `N` entries in the table at once using + /// `hash` and equality function from iterator. + /// + /// Returns an [`ArrayIter`] of length `N` with the results of each query. + /// + /// The `iter` argument should be an iterator that returns the `hash` of the stored + /// `element` and a closure for checking the equivalence of that `element`. + /// + /// The order of the elements in the returned iterator is the same as the order of the + /// elements in the search iterator `iter: &mut I`, except for elements that were not found. 
+ /// + /// Also, if `N` is less than the length of the iterator, the iterator will still be valid + /// and may continue to be used, in which case it will continue iterating from the element + /// remaining immediately after receiving `N` successful queries. + /// + /// # Safety + /// + /// Calling this method is *[undefined behavior]* if the iterator contains overlapping + /// items that refer to the same `elements` in the table, even if the resulting + /// references to `elements` in the table are not used. + /// + /// [undefined behavior]: https://doc.rust-lang.org/reference/behavior-considered-undefined.html + pub(crate) unsafe fn try_get_many_mut_unchecked<'a, I, F, const N: usize>( + &'a mut self, + iter: &mut I, + ) -> ArrayIter<&'a mut T, N> + where + I: Iterator, + F: FnMut(&T) -> bool, + { + if N == 0 { + // SAFETY: An empty array is always inhabited and has no validity invariants. + return ArrayIterBuilder::<&mut T, N>::new().build(); + } + + let mut builder = ArrayIterBuilder::<&mut T, N>::new(); + + for (hash, eq) in iter { + if let Some(bucket) = self.find(hash, eq) { + // SAFETY: + // 1. `N` is greater than 0, which was checked above + // 2. We break the loop immediately after initializing all elements + // (it's okay if we haven't initialized all the elements, but overflow + // must not be allowed) + // 3. The caller must uphold the safety contract for `try_get_many_mut_unchecked`. + unsafe { builder.push_unchecked(bucket.as_mut()) } + if builder.is_initialized() { + break; + } + } + } + builder.build() + } + /// Attempts to get mutable references to `N` entries in the table at once. /// /// Returns an array of length `N` with the results of each query.