Commit

Add Iter
Kestrer committed Jan 20, 2021
1 parent d9e93cb commit 1be15df
Showing 1 changed file with 114 additions and 34 deletions.
148 changes: 114 additions & 34 deletions src/lib.rs
@@ -78,11 +78,12 @@ pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal};
 
 use std::cell::UnsafeCell;
 use std::fmt;
+use std::iter::FusedIterator;
 use std::marker::PhantomData;
 use std::mem;
 use std::panic::UnwindSafe;
 use std::ptr;
-use std::sync::atomic::{AtomicPtr, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
 use std::sync::Mutex;
 use thread_id::Thread;
 use unreachable::{UncheckedOptionExt, UncheckedResultExt};
@@ -104,13 +105,22 @@ const BUCKETS: usize = (POINTER_WIDTH + 1) as usize;
 pub struct ThreadLocal<T: Send> {
     /// The buckets in the thread local. The nth bucket contains `2^(n-1)`
     /// elements. Each bucket is lazily allocated.
-    buckets: [AtomicPtr<UnsafeCell<Option<T>>>; BUCKETS],
+    buckets: [AtomicPtr<Entry<T>>; BUCKETS],
+
+    /// The number of values in the thread local. This can be less than the real number of values,
+    /// but is never more.
+    values: AtomicUsize,
 
     /// Lock used to guard against concurrent modifications. This is taken when
     /// there is a possibility of allocating a new bucket, which only occurs
-    /// when inserting values. This also guards the counter for the total number
-    /// of values in the thread local.
-    lock: Mutex<usize>,
+    /// when inserting values.
+    lock: Mutex<()>,
 }
 
+struct Entry<T> {
+    present: AtomicBool,
+    // Use MaybeUninit once the MSRV has been bumped.
+    value: UnsafeCell<Option<T>>,
+}
+
 // ThreadLocal is always Sync, even if T isn't
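The layout described by the `buckets` doc comment (bucket 0 holds one entry; bucket n holds `2^(n-1)` entries, so `POINTER_WIDTH + 1` buckets cover every possible thread id) implies a simple id-to-slot mapping. A minimal sketch with a hypothetical `bucket_and_index` helper; the crate computes this in its `thread_id` module, which this diff does not touch:

```rust
// Hypothetical helper illustrating the bucket layout; the crate's actual
// computation lives in the `thread_id` module and may differ in detail.
fn bucket_and_index(id: usize) -> (usize, usize) {
    if id == 0 {
        return (0, 0); // bucket 0 holds exactly one entry, for thread id 0
    }
    // Bucket n (n >= 1) holds 2^(n-1) entries, covering ids 2^(n-1)..2^n.
    let bits = std::mem::size_of::<usize>() * 8;
    let bucket = bits - id.leading_zeros() as usize;
    (bucket, id - (1 << (bucket - 1)))
}

fn main() {
    // Thread ids 0..8 land in buckets 0, 1, 2, 2, 3, 3, 3, 3.
    for id in 0..8 {
        println!("{} -> {:?}", id, bucket_and_index(id));
    }
}
```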
@@ -173,7 +183,8 @@ impl<T: Send> ThreadLocal<T> {
             // Safety: AtomicPtr has the same representation as a pointer and arrays have the same
             // representation as a sequence of their inner type.
             buckets: unsafe { mem::transmute(buckets) },
-            lock: Mutex::new(0),
+            values: AtomicUsize::new(0),
+            lock: Mutex::new(()),
         }
     }
 
@@ -215,14 +226,13 @@ impl<T: Send> ThreadLocal<T> {
         if bucket_ptr.is_null() {
             return None;
         }
-        unsafe { (&*(&*bucket_ptr.add(thread.index)).get()).as_ref() }
+        unsafe { (&*(*bucket_ptr.add(thread.index)).value.get()).as_ref() }
     }
 
     #[cold]
     fn insert(&self, thread: Thread, data: T) -> &T {
         // Lock the Mutex to ensure only a single thread is allocating buckets at once
-        let mut count = self.lock.lock().unwrap();
-        *count += 1;
+        let _guard = self.lock.lock().unwrap();
 
         let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
 
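This hunk and the struct change above split the old `Mutex<usize>` in two: the mutex now only serializes bucket allocation, while the count moves into the separate `values: AtomicUsize` so readers (notably the new `iter`) can observe it without locking. A minimal model of that refactor, with hypothetical names:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;

// Hypothetical stand-in for ThreadLocal's lock/values pair.
struct Counted {
    lock: Mutex<()>,     // serializes bucket allocation only
    values: AtomicUsize, // readable without taking the lock
}

impl Counted {
    fn insert(&self) {
        let guard = self.lock.lock().unwrap();
        // ... allocate a bucket here if needed ...
        drop(guard);
        self.values.fetch_add(1, Ordering::Release);
    }

    fn count_hint(&self) -> usize {
        // May lag behind a concurrent insert, hence "can be less than the
        // real number of values, but is never more".
        self.values.load(Ordering::Acquire)
    }
}

fn main() {
    let c = Counted { lock: Mutex::new(()), values: AtomicUsize::new(0) };
    c.insert();
    assert_eq!(c.count_hint(), 1);
}
```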
@@ -236,22 +246,40 @@ impl<T: Send> ThreadLocal<T> {
             bucket_ptr
         };
 
-        drop(count);
+        drop(_guard);
 
         // Insert the new element into the bucket
-        unsafe {
-            let value_ptr = (&*bucket_ptr.add(thread.index)).get();
-            *value_ptr = Some(data);
-            (&*value_ptr).as_ref().unchecked_unwrap()
-        }
+        let entry = unsafe { &*bucket_ptr.add(thread.index) };
+        let value_ptr = entry.value.get();
+        unsafe { value_ptr.write(Some(data)) };
+        entry.present.store(true, Ordering::Release);
+
+        self.values.fetch_add(1, Ordering::Release);
+
+        unsafe { (&*value_ptr).as_ref().unchecked_unwrap() }
     }
 
+    /// Returns an iterator over the local values of all threads in unspecified
+    /// order.
+    ///
+    /// This call can be done safely, as `T` is required to implement [`Sync`].
+    pub fn iter(&self) -> Iter<'_, T>
+    where
+        T: Sync,
+    {
+        Iter {
+            thread_local: self,
+            yielded: 0,
+            bucket: 0,
+            bucket_size: 1,
+            index: 0,
+        }
+    }
+
-    fn raw_iter(&mut self) -> RawIter<T> {
-        RawIter {
-            remaining: *self.lock.get_mut().unwrap(),
-            buckets: unsafe {
-                *(&self.buckets as *const _ as *const [*const UnsafeCell<Option<T>>; BUCKETS])
-            },
+    fn raw_iter_mut(&mut self) -> RawIterMut<T> {
+        RawIterMut {
+            remaining: *self.values.get_mut(),
+            buckets: unsafe { *(&self.buckets as *const _ as *const [*mut Entry<T>; BUCKETS]) },
             bucket: 0,
             bucket_size: 1,
             index: 0,
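The new `iter` only needs `&self`, unlike `iter_mut` and `into_iter` below. A usage sketch (assuming the crate's existing `get_or` accessor):

```rust
use std::sync::Arc;
use std::thread;
use thread_local::ThreadLocal;

fn main() {
    let tls = Arc::new(ThreadLocal::new());

    // Each spawned thread lazily initializes its own slot.
    let handles: Vec<_> = (1..=3)
        .map(|i| {
            let tls = Arc::clone(&tls);
            thread::spawn(move || {
                tls.get_or(|| i);
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }

    // New in this commit: iterate through a shared reference, even while
    // other threads still hold the Arc.
    let mut values: Vec<i32> = tls.iter().copied().collect();
    values.sort_unstable();
    assert_eq!(values, vec![1, 2, 3]);
}
```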
@@ -266,7 +294,7 @@ impl<T: Send> ThreadLocal<T> {
     /// threads are currently accessing their associated values.
     pub fn iter_mut(&mut self) -> IterMut<T> {
         IterMut {
-            raw: self.raw_iter(),
+            raw: self.raw_iter_mut(),
             marker: PhantomData,
         }
     }
@@ -288,7 +316,7 @@ impl<T: Send> IntoIterator for ThreadLocal<T> {
 
     fn into_iter(mut self) -> IntoIter<T> {
         IntoIter {
-            raw: self.raw_iter(),
+            raw: self.raw_iter_mut(),
             _thread_local: self,
         }
     }
@@ -319,15 +347,60 @@ impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {
 
 impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}
 
-struct RawIter<T: Send> {
+/// Iterator over the contents of a `ThreadLocal`.
+pub struct Iter<'a, T: Send + Sync + 'a> {
+    thread_local: &'a ThreadLocal<T>,
+    yielded: usize,
+    bucket: usize,
+    bucket_size: usize,
+    index: usize,
+}
+
+impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
+    type Item = &'a T;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        while self.bucket < BUCKETS {
+            let bucket = unsafe { self.thread_local.buckets.get_unchecked(self.bucket) };
+            let bucket = bucket.load(Ordering::Relaxed);
+
+            if !bucket.is_null() {
+                while self.index < self.bucket_size {
+                    let entry = unsafe { &*bucket.add(self.index) };
+                    self.index += 1;
+                    if entry.present.load(Ordering::Acquire) {
+                        self.yielded += 1;
+                        return Some(unsafe { (&*entry.value.get()).as_ref().unchecked_unwrap() });
+                    }
+                }
+            }
+
+            if self.bucket != 0 {
+                self.bucket_size <<= 1;
+            }
+            self.bucket += 1;
+
+            self.index = 0;
+        }
+        None
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let total = self.thread_local.values.load(Ordering::Acquire);
+        (total - self.yielded, None)
+    }
+}
+impl<'a, T: Send + Sync> FusedIterator for Iter<'a, T> {}
+
+struct RawIterMut<T: Send> {
     remaining: usize,
-    buckets: [*const UnsafeCell<Option<T>>; BUCKETS],
+    buckets: [*mut Entry<T>; BUCKETS],
     bucket: usize,
     bucket_size: usize,
     index: usize,
 }
 
-impl<T: Send> Iterator for RawIter<T> {
+impl<T: Send> Iterator for RawIterMut<T> {
     type Item = *mut Option<T>;
 
     fn next(&mut self) -> Option<Self::Item> {
@@ -340,13 +413,11 @@ impl<T: Send> Iterator for RawIter<T> {
 
             if !bucket.is_null() {
                 while self.index < self.bucket_size {
-                    let item = unsafe { (&*bucket.add(self.index)).get() };
-
+                    let entry = unsafe { &mut *bucket.add(self.index) };
                     self.index += 1;
-
-                    if unsafe { &*item }.is_some() {
+                    if *entry.present.get_mut() {
                         self.remaining -= 1;
-                        return Some(item);
+                        return Some(entry.value.get());
                     }
                 }
             }
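The soundness of `Iter::next` hangs on the `present` flag: `insert` writes the value first and then stores `present` with `Release`, so the `Acquire` load in the iterator guarantees that a reader which sees `true` also sees the initialized value. A stripped-down model of that publication protocol (a sketch, not the crate's code; assumes a single writing thread per slot, as in `ThreadLocal`):

```rust
use std::cell::UnsafeCell;
use std::sync::atomic::{AtomicBool, Ordering};

// Minimal model of Entry<T>'s publish/read protocol.
struct Slot<T> {
    present: AtomicBool,
    value: UnsafeCell<Option<T>>,
}

// Sound only because `publish` is called by a single owning thread.
unsafe impl<T: Send + Sync> Sync for Slot<T> {}

impl<T> Slot<T> {
    // Called only by the slot's owning thread, at most once.
    fn publish(&self, v: T) {
        unsafe { *self.value.get() = Some(v) }; // 1. write the data
        self.present.store(true, Ordering::Release); // 2. then flip the flag
    }

    // Callable from any thread; this is what the iterator does per entry.
    fn read(&self) -> Option<&T> {
        if self.present.load(Ordering::Acquire) {
            // Acquire synchronizes-with the Release store in `publish`,
            // making the write to `value` visible here.
            unsafe { (*self.value.get()).as_ref() }
        } else {
            None
        }
    }
}

fn main() {
    let slot = Slot {
        present: AtomicBool::new(false),
        value: UnsafeCell::new(None),
    };
    assert!(slot.read().is_none());
    slot.publish(7);
    assert_eq!(slot.read(), Some(&7));
}
```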
@@ -367,7 +438,7 @@ impl<T: Send> Iterator for RawIter<T> {
 
 /// Mutable iterator over the contents of a `ThreadLocal`.
 pub struct IterMut<'a, T: Send + 'a> {
-    raw: RawIter<T>,
+    raw: RawIterMut<T>,
     marker: PhantomData<&'a mut ThreadLocal<T>>,
 }
 
@@ -389,7 +460,7 @@ impl<'a, T: Send + 'a> ExactSizeIterator for IterMut<'a, T> {}
 
 /// An iterator that moves out of a `ThreadLocal`.
 pub struct IntoIter<T: Send> {
-    raw: RawIter<T>,
+    raw: RawIterMut<T>,
     _thread_local: ThreadLocal<T>,
 }
 
@@ -409,10 +480,13 @@ impl<T: Send> Iterator for IntoIter<T> {
 
 impl<T: Send> ExactSizeIterator for IntoIter<T> {}
 
-fn allocate_bucket<T>(size: usize) -> *mut UnsafeCell<Option<T>> {
+fn allocate_bucket<T>(size: usize) -> *mut Entry<T> {
     Box::into_raw(
         (0..size)
-            .map(|_| UnsafeCell::new(None::<T>))
+            .map(|_| Entry::<T> {
+                present: AtomicBool::new(false),
+                value: UnsafeCell::new(None),
+            })
            .collect::<Vec<_>>()
            .into_boxed_slice(),
    ) as *mut _
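`allocate_bucket` leaks the boxed slice via `Box::into_raw`; the matching cleanup lives in the crate's `Drop` impl, which this diff does not touch. A hedged sketch of that counterpart:

```rust
// Hypothetical counterpart to allocate_bucket: `size` must be the exact
// size the bucket was allocated with, and `ptr` must not be used again.
unsafe fn deallocate_bucket<T>(ptr: *mut Entry<T>, size: usize) {
    let _ = Box::from_raw(std::slice::from_raw_parts_mut(ptr, size) as *mut [Entry<T>]);
}
```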
@@ -491,9 +565,15 @@ mod tests {
             .unwrap();
 
         let mut tls = Arc::try_unwrap(tls).unwrap();
+
+        let mut v = tls.iter().map(|x| **x).collect::<Vec<i32>>();
+        v.sort_unstable();
+        assert_eq!(vec![1, 2, 3], v);
+
         let mut v = tls.iter_mut().map(|x| **x).collect::<Vec<i32>>();
         v.sort_unstable();
         assert_eq!(vec![1, 2, 3], v);
+
         let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>();
         v.sort_unstable();
         assert_eq!(vec![1, 2, 3], v);