From 810c043ff71ef66b96ca5d92319df4aa7134bf44 Mon Sep 17 00:00:00 2001 From: Kestrer Date: Sun, 24 Jan 2021 06:56:40 +0000 Subject: [PATCH] Implement iterator logic in RawIter --- src/lib.rs | 182 +++++++++++++++++++++++------------------------------ 1 file changed, 80 insertions(+), 102 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 16eadb6..792a5b2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,7 +76,6 @@ pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal}; use std::cell::UnsafeCell; use std::fmt; use std::iter::FusedIterator; -use std::marker::PhantomData; use std::mem; use std::mem::MaybeUninit; use std::panic::UnwindSafe; @@ -274,20 +273,7 @@ impl ThreadLocal { { Iter { thread_local: self, - yielded: 0, - bucket: 0, - bucket_size: 1, - index: 0, - } - } - - fn raw_iter_mut(&mut self) -> RawIterMut { - RawIterMut { - remaining: *self.values.get_mut(), - buckets: unsafe { *(&self.buckets as *const _ as *const [*mut Entry; BUCKETS]) }, - bucket: 0, - bucket_size: 1, - index: 0, + raw: RawIter::new(), } } @@ -299,8 +285,8 @@ impl ThreadLocal { /// threads are currently accessing their associated values. pub fn iter_mut(&mut self) -> IterMut { IterMut { - raw: self.raw_iter_mut(), - marker: PhantomData, + thread_local: self, + raw: RawIter::new(), } } @@ -319,10 +305,10 @@ impl IntoIterator for ThreadLocal { type Item = T; type IntoIter = IntoIter; - fn into_iter(mut self) -> IntoIter { + fn into_iter(self) -> IntoIter { IntoIter { - raw: self.raw_iter_mut(), - _thread_local: self, + thread_local: self, + raw: RawIter::new(), } } } @@ -361,22 +347,26 @@ impl fmt::Debug for ThreadLocal { impl UnwindSafe for ThreadLocal {} -/// Iterator over the contents of a `ThreadLocal`. #[derive(Debug)] -pub struct Iter<'a, T: Send + Sync> { - thread_local: &'a ThreadLocal, +struct RawIter { yielded: usize, bucket: usize, bucket_size: usize, index: usize, } +impl RawIter { + fn new() -> Self { + Self { + yielded: 0, + bucket: 0, + bucket_size: 1, + index: 0, + } + } -impl<'a, T: Send + Sync> Iterator for Iter<'a, T> { - type Item = &'a T; - - fn next(&mut self) -> Option { + fn next<'a, T: Send + Sync>(&mut self, thread_local: &'a ThreadLocal) -> Option<&'a T> { while self.bucket < BUCKETS { - let bucket = unsafe { self.thread_local.buckets.get_unchecked(self.bucket) }; + let bucket = unsafe { thread_local.buckets.get_unchecked(self.bucket) }; let bucket = bucket.load(Ordering::Relaxed); if !bucket.is_null() { @@ -390,140 +380,128 @@ impl<'a, T: Send + Sync> Iterator for Iter<'a, T> { } } - if self.bucket != 0 { - self.bucket_size <<= 1; - } - self.bucket += 1; - - self.index = 0; + self.next_bucket(); } None } - - fn size_hint(&self) -> (usize, Option) { - let total = self.thread_local.values.load(Ordering::Acquire); - (total - self.yielded, None) - } -} -impl FusedIterator for Iter<'_, T> {} - -struct RawIterMut { - remaining: usize, - buckets: [*mut Entry; BUCKETS], - bucket: usize, - bucket_size: usize, - index: usize, -} - -impl Iterator for RawIterMut { - type Item = *mut MaybeUninit; - - fn next(&mut self) -> Option { - if self.remaining == 0 { + fn next_mut<'a, T: Send>( + &mut self, + thread_local: &'a mut ThreadLocal, + ) -> Option<&'a mut Entry> { + if *thread_local.values.get_mut() == self.yielded { return None; } loop { - let bucket = unsafe { *self.buckets.get_unchecked(self.bucket) }; + let bucket = unsafe { thread_local.buckets.get_unchecked_mut(self.bucket) }; + let bucket = *bucket.get_mut(); if !bucket.is_null() { while self.index < self.bucket_size { let entry = unsafe { &mut *bucket.add(self.index) }; self.index += 1; if *entry.present.get_mut() { - self.remaining -= 1; - return Some(entry.value.get()); + self.yielded += 1; + return Some(entry); } } } - if self.bucket != 0 { - self.bucket_size <<= 1; - } - self.bucket += 1; + self.next_bucket(); + } + } - self.index = 0; + fn next_bucket(&mut self) { + if self.bucket != 0 { + self.bucket_size <<= 1; } + self.bucket += 1; + self.index = 0; } - fn size_hint(&self) -> (usize, Option) { - (self.remaining, Some(self.remaining)) + fn size_hint(&self, thread_local: &ThreadLocal) -> (usize, Option) { + let total = thread_local.values.load(Ordering::Acquire); + (total - self.yielded, None) + } + fn size_hint_frozen(&self, thread_local: &ThreadLocal) -> (usize, Option) { + let total = unsafe { *(&thread_local.values as *const AtomicUsize as *const usize) }; + let remaining = total - self.yielded; + (remaining, Some(remaining)) } } -unsafe impl Send for RawIterMut {} -unsafe impl Sync for RawIterMut {} +/// Iterator over the contents of a `ThreadLocal`. +#[derive(Debug)] +pub struct Iter<'a, T: Send + Sync> { + thread_local: &'a ThreadLocal, + raw: RawIter, +} + +impl<'a, T: Send + Sync> Iterator for Iter<'a, T> { + type Item = &'a T; + fn next(&mut self) -> Option { + self.raw.next(self.thread_local) + } + fn size_hint(&self) -> (usize, Option) { + self.raw.size_hint(self.thread_local) + } +} +impl FusedIterator for Iter<'_, T> {} /// Mutable iterator over the contents of a `ThreadLocal`. pub struct IterMut<'a, T: Send> { - raw: RawIterMut, - marker: PhantomData<&'a mut ThreadLocal>, + thread_local: &'a mut ThreadLocal, + raw: RawIter, } impl<'a, T: Send> Iterator for IterMut<'a, T> { type Item = &'a mut T; - fn next(&mut self) -> Option<&'a mut T> { self.raw - .next() - .map(|x| unsafe { &mut *(&mut *x).as_mut_ptr() }) + .next_mut(self.thread_local) + .map(|entry| unsafe { &mut *(&mut *entry.value.get()).as_mut_ptr() }) } - fn size_hint(&self) -> (usize, Option) { - self.raw.size_hint() + self.raw.size_hint_frozen(self.thread_local) } } impl ExactSizeIterator for IterMut<'_, T> {} impl FusedIterator for IterMut<'_, T> {} -// The Debug bound is technically unnecessary but makes the API more consistent and future-proof. -impl fmt::Debug for IterMut<'_, T> { +// Manual impl so we don't call Debug on the ThreadLocal, as doing so would create a reference to +// this thread's value that potentially aliases with a mutable reference we have given out. +impl<'a, T: Send + fmt::Debug> fmt::Debug for IterMut<'a, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("IterMut") - .field("remaining", &self.raw.remaining) - .field("bucket", &self.raw.bucket) - .field("bucket_size", &self.raw.bucket_size) - .field("index", &self.raw.index) - .finish() + f.debug_struct("IterMut").field("raw", &self.raw).finish() } } /// An iterator that moves out of a `ThreadLocal`. +#[derive(Debug)] pub struct IntoIter { - raw: RawIterMut, - _thread_local: ThreadLocal, + thread_local: ThreadLocal, + raw: RawIter, } impl Iterator for IntoIter { type Item = T; - fn next(&mut self) -> Option { - self.raw - .next() - .map(|x| unsafe { std::mem::replace(&mut *x, MaybeUninit::uninit()).assume_init() }) + self.raw.next_mut(&mut self.thread_local).map(|entry| { + *entry.present.get_mut() = false; + unsafe { + std::mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init() + } + }) } - fn size_hint(&self) -> (usize, Option) { - self.raw.size_hint() + self.raw.size_hint_frozen(&self.thread_local) } } impl ExactSizeIterator for IntoIter {} impl FusedIterator for IntoIter {} -// The Debug bound is technically unnecessary but makes the API more consistent and future-proof. -impl fmt::Debug for IntoIter { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("IntoIter") - .field("remaining", &self.raw.remaining) - .field("bucket", &self.raw.bucket) - .field("bucket_size", &self.raw.bucket_size) - .field("index", &self.raw.index) - .finish() - } -} - fn allocate_bucket(size: usize) -> *mut Entry { Box::into_raw( (0..size)