From 1be15df2d1d72d1bc8757e0bde5fb3a5b0ebe870 Mon Sep 17 00:00:00 2001 From: Koxiaet Date: Wed, 20 Jan 2021 06:49:50 +0000 Subject: [PATCH] Add Iter --- src/lib.rs | 148 +++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 114 insertions(+), 34 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 78bdcc3..ae96a9c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -78,11 +78,12 @@ pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal}; use std::cell::UnsafeCell; use std::fmt; +use std::iter::FusedIterator; use std::marker::PhantomData; use std::mem; use std::panic::UnwindSafe; use std::ptr; -use std::sync::atomic::{AtomicPtr, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering}; use std::sync::Mutex; use thread_id::Thread; use unreachable::{UncheckedOptionExt, UncheckedResultExt}; @@ -104,13 +105,22 @@ const BUCKETS: usize = (POINTER_WIDTH + 1) as usize; pub struct ThreadLocal { /// The buckets in the thread local. The nth bucket contains `2^(n-1)` /// elements. Each bucket is lazily allocated. - buckets: [AtomicPtr>>; BUCKETS], + buckets: [AtomicPtr>; BUCKETS], + + /// The number of values in the thread local. This can be less than the real number of values, + /// but is never more. + values: AtomicUsize, /// Lock used to guard against concurrent modifications. This is taken when /// there is a possibility of allocating a new bucket, which only occurs - /// when inserting values. This also guards the counter for the total number - /// of values in the thread local. - lock: Mutex, + /// when inserting values. + lock: Mutex<()>, +} + +struct Entry { + present: AtomicBool, + // Use MaybeUninit once the MSRV has been bumped. + value: UnsafeCell>, } // ThreadLocal is always Sync, even if T isn't @@ -173,7 +183,8 @@ impl ThreadLocal { // Safety: AtomicPtr has the same representation as a pointer and arrays have the same // representation as a sequence of their inner type. buckets: unsafe { mem::transmute(buckets) }, - lock: Mutex::new(0), + values: AtomicUsize::new(0), + lock: Mutex::new(()), } } @@ -215,14 +226,13 @@ impl ThreadLocal { if bucket_ptr.is_null() { return None; } - unsafe { (&*(&*bucket_ptr.add(thread.index)).get()).as_ref() } + unsafe { (&*(*bucket_ptr.add(thread.index)).value.get()).as_ref() } } #[cold] fn insert(&self, thread: Thread, data: T) -> &T { // Lock the Mutex to ensure only a single thread is allocating buckets at once - let mut count = self.lock.lock().unwrap(); - *count += 1; + let _guard = self.lock.lock().unwrap(); let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) }; @@ -236,22 +246,40 @@ impl ThreadLocal { bucket_ptr }; - drop(count); + drop(_guard); // Insert the new element into the bucket - unsafe { - let value_ptr = (&*bucket_ptr.add(thread.index)).get(); - *value_ptr = Some(data); - (&*value_ptr).as_ref().unchecked_unwrap() + let entry = unsafe { &*bucket_ptr.add(thread.index) }; + let value_ptr = entry.value.get(); + unsafe { value_ptr.write(Some(data)) }; + entry.present.store(true, Ordering::Release); + + self.values.fetch_add(1, Ordering::Release); + + unsafe { (&*value_ptr).as_ref().unchecked_unwrap() } + } + + /// Returns an iterator over the local values of all threads in unspecified + /// order. + /// + /// This call can be done safely, as `T` is required to implement [`Sync`]. + pub fn iter(&self) -> Iter<'_, T> + where + T: Sync, + { + Iter { + thread_local: self, + yielded: 0, + bucket: 0, + bucket_size: 1, + index: 0, } } - fn raw_iter(&mut self) -> RawIter { - RawIter { - remaining: *self.lock.get_mut().unwrap(), - buckets: unsafe { - *(&self.buckets as *const _ as *const [*const UnsafeCell>; BUCKETS]) - }, + fn raw_iter_mut(&mut self) -> RawIterMut { + RawIterMut { + remaining: *self.values.get_mut(), + buckets: unsafe { *(&self.buckets as *const _ as *const [*mut Entry; BUCKETS]) }, bucket: 0, bucket_size: 1, index: 0, @@ -266,7 +294,7 @@ impl ThreadLocal { /// threads are currently accessing their associated values. pub fn iter_mut(&mut self) -> IterMut { IterMut { - raw: self.raw_iter(), + raw: self.raw_iter_mut(), marker: PhantomData, } } @@ -288,7 +316,7 @@ impl IntoIterator for ThreadLocal { fn into_iter(mut self) -> IntoIter { IntoIter { - raw: self.raw_iter(), + raw: self.raw_iter_mut(), _thread_local: self, } } @@ -319,15 +347,60 @@ impl fmt::Debug for ThreadLocal { impl UnwindSafe for ThreadLocal {} -struct RawIter { +/// Iterator over the contents of a `ThreadLocal`. +pub struct Iter<'a, T: Send + Sync + 'a> { + thread_local: &'a ThreadLocal, + yielded: usize, + bucket: usize, + bucket_size: usize, + index: usize, +} + +impl<'a, T: Send + Sync> Iterator for Iter<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + while self.bucket < BUCKETS { + let bucket = unsafe { self.thread_local.buckets.get_unchecked(self.bucket) }; + let bucket = bucket.load(Ordering::Relaxed); + + if !bucket.is_null() { + while self.index < self.bucket_size { + let entry = unsafe { &*bucket.add(self.index) }; + self.index += 1; + if entry.present.load(Ordering::Acquire) { + self.yielded += 1; + return Some(unsafe { (&*entry.value.get()).as_ref().unchecked_unwrap() }); + } + } + } + + if self.bucket != 0 { + self.bucket_size <<= 1; + } + self.bucket += 1; + + self.index = 0; + } + None + } + + fn size_hint(&self) -> (usize, Option) { + let total = self.thread_local.values.load(Ordering::Acquire); + (total - self.yielded, None) + } +} +impl<'a, T: Send + Sync> FusedIterator for Iter<'a, T> {} + +struct RawIterMut { remaining: usize, - buckets: [*const UnsafeCell>; BUCKETS], + buckets: [*mut Entry; BUCKETS], bucket: usize, bucket_size: usize, index: usize, } -impl Iterator for RawIter { +impl Iterator for RawIterMut { type Item = *mut Option; fn next(&mut self) -> Option { @@ -340,13 +413,11 @@ impl Iterator for RawIter { if !bucket.is_null() { while self.index < self.bucket_size { - let item = unsafe { (&*bucket.add(self.index)).get() }; - + let entry = unsafe { &mut *bucket.add(self.index) }; self.index += 1; - - if unsafe { &*item }.is_some() { + if *entry.present.get_mut() { self.remaining -= 1; - return Some(item); + return Some(entry.value.get()); } } } @@ -367,7 +438,7 @@ impl Iterator for RawIter { /// Mutable iterator over the contents of a `ThreadLocal`. pub struct IterMut<'a, T: Send + 'a> { - raw: RawIter, + raw: RawIterMut, marker: PhantomData<&'a mut ThreadLocal>, } @@ -389,7 +460,7 @@ impl<'a, T: Send + 'a> ExactSizeIterator for IterMut<'a, T> {} /// An iterator that moves out of a `ThreadLocal`. pub struct IntoIter { - raw: RawIter, + raw: RawIterMut, _thread_local: ThreadLocal, } @@ -409,10 +480,13 @@ impl Iterator for IntoIter { impl ExactSizeIterator for IntoIter {} -fn allocate_bucket(size: usize) -> *mut UnsafeCell> { +fn allocate_bucket(size: usize) -> *mut Entry { Box::into_raw( (0..size) - .map(|_| UnsafeCell::new(None::)) + .map(|_| Entry:: { + present: AtomicBool::new(false), + value: UnsafeCell::new(None), + }) .collect::>() .into_boxed_slice(), ) as *mut _ @@ -491,9 +565,15 @@ mod tests { .unwrap(); let mut tls = Arc::try_unwrap(tls).unwrap(); + + let mut v = tls.iter().map(|x| **x).collect::>(); + v.sort_unstable(); + assert_eq!(vec![1, 2, 3], v); + let mut v = tls.iter_mut().map(|x| **x).collect::>(); v.sort_unstable(); assert_eq!(vec![1, 2, 3], v); + let mut v = tls.into_iter().map(|x| *x).collect::>(); v.sort_unstable(); assert_eq!(vec![1, 2, 3], v);