diff --git a/src/map.rs b/src/map.rs index fcd0e53d5a..5bb4a9d358 100644 --- a/src/map.rs +++ b/src/map.rs @@ -956,6 +956,35 @@ where } } } + /// Drains elements which are false under the given predicate, + /// and returns an iterator over the removed items. + /// + /// In other words, move all pairs `(k, v)` such that `f(&k,&mut v)` returns `false` out + /// into another iterator. + /// + /// When the returned DrainedFilter is dropped, the elements that don't satisfy + /// the predicate are dropped from the table. + /// + /// # Examples + /// + /// ``` + /// use hashbrown::HashMap; + /// + /// let mut map: HashMap = (0..8).map(|x|(x, x*10)).collect(); + /// let drained = map.drain_filter(|&k, _| k % 2 == 0); + /// assert_eq!(drained.count(), 4); + /// assert_eq!(map.len(), 4); + /// ``` + pub fn drain_filter(&mut self, f: F) -> DrainFilter<'_, K, V, F> + where + F: FnMut(&K, &mut V) -> bool, + { + DrainFilter { + f, + iter: unsafe { self.table.iter() }, + table: &mut self.table, + } + } } impl HashMap { @@ -1236,6 +1265,66 @@ impl Drain<'_, K, V> { } } +/// A draining iterator over entries of a `HashMap` which don't satisfy the predicate `f`. +/// +/// This `struct` is created by the [`drain_filter`] method on [`HashMap`]. See its +/// documentation for more. +/// +/// [`drain_filter`]: struct.HashMap.html#method.drain_filter +/// [`HashMap`]: struct.HashMap.html +pub struct DrainFilter<'a, K, V, F> +where + F: FnMut(&K, &mut V) -> bool, +{ + f: F, + iter: RawIter<(K, V)>, + table: &'a mut RawTable<(K, V)>, +} + +impl<'a, K, V, F> Drop for DrainFilter<'a, K, V, F> +where + F: FnMut(&K, &mut V) -> bool, +{ + fn drop(&mut self) { + struct DropGuard<'r, 'a, K, V, F>(&'r mut DrainFilter<'a, K, V, F>) + where + F: FnMut(&K, &mut V) -> bool; + + impl<'r, 'a, K, V, F> Drop for DropGuard<'r, 'a, K, V, F> + where + F: FnMut(&K, &mut V) -> bool, + { + fn drop(&mut self) { + while let Some(_) = self.0.next() {} + } + } + while let Some(item) = self.next() { + let guard = DropGuard(self); + drop(item); + mem::forget(guard); + } + } +} + +impl Iterator for DrainFilter<'_, K, V, F> +where + F: FnMut(&K, &mut V) -> bool, +{ + type Item = (K, V); + fn next(&mut self) -> Option { + unsafe { + while let Some(item) = self.iter.next() { + let &mut (ref key, ref mut value) = item.as_mut(); + if !(self.f)(key, value) { + self.table.erase_no_drop(&item); + return Some(item.read()); + } + } + } + None + } +} + /// A mutable iterator over the values of a `HashMap`. /// /// This `struct` is created by the [`values_mut`] method on [`HashMap`]. See its @@ -3488,6 +3577,23 @@ mod test_map { assert_eq!(map[&6], 60); } + #[test] + fn test_drain_filter() { + { + let mut map: HashMap = (0..8).map(|x| (x, x * 10)).collect(); + let drained = map.drain_filter(|&k, _| k % 2 == 0); + let mut out = drained.collect::>(); + out.sort_unstable(); + assert_eq!(vec![(1, 10), (3, 30), (5, 50), (7, 70)], out); + assert_eq!(map.len(), 4); + } + { + let mut map: HashMap = (0..8).map(|x| (x, x * 10)).collect(); + drop(map.drain_filter(|&k, _| k % 2 == 0)); + assert_eq!(map.len(), 4); + } + } + #[test] #[cfg_attr(miri, ignore)] // FIXME: no OOM signalling (https://github.com/rust-lang/miri/issues/613) fn test_try_reserve() {