Cache conscious hashmap table
arthurprs committed Oct 12, 2016
1 parent a7bfb1a commit c435821
Showing 1 changed file with 68 additions and 88 deletions.
156 changes: 68 additions & 88 deletions src/libstd/collections/hash/table.rs
@@ -24,10 +24,10 @@ use self::BucketState::*;
const EMPTY_BUCKET: u64 = 0;

/// The raw hashtable, providing safe-ish access to the unzipped and highly
- /// optimized arrays of hashes, keys, and values.
+ /// optimized arrays of hashes, and key-value pairs.
///
- /// This design uses less memory and is a lot faster than the naive
- /// `Vec<Option<u64, K, V>>`, because we don't pay for the overhead of an
+ /// This design is a lot faster than the naive
+ /// `Vec<Option<(u64, K, V)>>`, because we don't pay for the overhead of an
/// option on every element, and we get a generally more cache-aware design.
///
/// Essential invariants of this structure:
@@ -48,17 +48,19 @@ const EMPTY_BUCKET: u64 = 0;
/// which will likely map to the same bucket, while not being confused
/// with "empty".
///
- /// - All three "arrays represented by pointers" are the same length:
+ /// - Both "arrays represented by pointers" are the same length:
/// `capacity`. This is set at creation and never changes. The arrays
- /// are unzipped to save space (we don't have to pay for the padding
- /// between odd sized elements, such as in a map from u64 to u8), and
- /// be more cache aware (scanning through 8 hashes brings in at most
- /// 2 cache lines, since they're all right beside each other).
+ /// are unzipped and are more cache aware (scanning through 8 hashes
+ /// brings in at most 2 cache lines, since they're all right beside each
+ /// other). This layout may waste space in padding such as in a map from
+ /// u64 to u8, but is a more cache conscious layout as the key-value pairs
+ /// are only very shortly probed and the desired value will be in the same
+ /// or next cache line.
///
/// You can kind of think of this module/data structure as a safe wrapper
/// around just the "table" part of the hashtable. It enforces some
/// invariants at the type level and employs some performance trickery,
- /// but in general is just a tricked out `Vec<Option<u64, K, V>>`.
+ /// but in general is just a tricked out `Vec<Option<(u64, K, V)>>`.
pub struct RawTable<K, V> {
capacity: usize,
size: usize,
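To make the padding-versus-locality trade-off described in the comment above concrete, here is a small illustrative snippet (my addition, not part of the commit); the byte counts in the comments assume a typical 64-bit target where `(u64, u8)` is padded to 16 bytes.

```rust
use std::mem::size_of;

fn main() {
    // Fully unzipped arrays (old layout): hash, key, and value stored separately.
    let unzipped = size_of::<u64>() + size_of::<u64>() + size_of::<u8>(); // 17 here
    // Zipped pairs (new layout): hash array plus a (key, value) pair array.
    let zipped = size_of::<u64>() + size_of::<(u64, u8)>(); // 24 here
    println!("bytes per entry: unzipped = {}, zipped = {}", unzipped, zipped);
    // The pair is padded to its 8-byte alignment, so the zipped layout spends
    // more memory, but a probe that matches a hash finds the key and value
    // adjacent, typically in the same or the next cache line.
}
```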
@@ -74,10 +76,8 @@ unsafe impl<K: Sync, V: Sync> Sync for RawTable<K, V> {}

struct RawBucket<K, V> {
hash: *mut u64,

// We use *const to ensure covariance with respect to K and V
- key: *const K,
- val: *const V,
+ pair: *const (K, V),
_marker: marker::PhantomData<(K, V)>,
}
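As an aside on the `*const` comment above: the following compile-only sketch (my illustration, not from the commit) shows the property at stake; with a `*mut T` field instead, the struct becomes invariant in `T` and the coercion below is rejected.

```rust
use std::marker::PhantomData;

// Same shape as RawBucket: a raw pointer plus a PhantomData marker.
struct Bucket<T> {
    ptr: *const T, // `*const` keeps Bucket covariant in T
    _marker: PhantomData<T>,
}

// Covariance allows a bucket over the longer ('static) lifetime to be used
// where a bucket over a shorter lifetime is expected.
fn shorten<'a>(b: Bucket<&'static str>) -> Bucket<&'a str> {
    b
}

fn main() {}
```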

@@ -181,8 +181,7 @@ impl<K, V> RawBucket<K, V> {
unsafe fn offset(self, count: isize) -> RawBucket<K, V> {
RawBucket {
hash: self.hash.offset(count),
- key: self.key.offset(count),
- val: self.val.offset(count),
+ pair: self.pair.offset(count),
_marker: marker::PhantomData,
}
}
@@ -370,8 +369,7 @@ impl<K, V, M> EmptyBucket<K, V, M>
pub fn put(mut self, hash: SafeHash, key: K, value: V) -> FullBucket<K, V, M> {
unsafe {
*self.raw.hash = hash.inspect();
- ptr::write(self.raw.key as *mut K, key);
- ptr::write(self.raw.val as *mut V, value);
+ ptr::write(self.raw.pair as *mut (K, V), (key, value));

self.table.borrow_table_mut().size += 1;
}
@@ -430,7 +428,7 @@ impl<K, V, M: Deref<Target = RawTable<K, V>>> FullBucket<K, V, M> {

/// Gets references to the key and value at a given index.
pub fn read(&self) -> (&K, &V) {
- unsafe { (&*self.raw.key, &*self.raw.val) }
+ unsafe { (&(*self.raw.pair).0, &(*self.raw.pair).1) }
}
}

@@ -447,13 +445,14 @@ impl<'t, K, V> FullBucket<K, V, &'t mut RawTable<K, V>> {

unsafe {
*self.raw.hash = EMPTY_BUCKET;
+ let (k, v) = ptr::read(self.raw.pair);
(EmptyBucket {
raw: self.raw,
idx: self.idx,
table: self.table,
},
- ptr::read(self.raw.key),
- ptr::read(self.raw.val))
+ k,
+ v)
}
}
}
@@ -466,8 +465,7 @@ impl<K, V, M> FullBucket<K, V, M>
pub fn replace(&mut self, h: SafeHash, k: K, v: V) -> (SafeHash, K, V) {
unsafe {
let old_hash = ptr::replace(self.raw.hash as *mut SafeHash, h);
- let old_key = ptr::replace(self.raw.key as *mut K, k);
- let old_val = ptr::replace(self.raw.val as *mut V, v);
+ let (old_key, old_val) = ptr::replace(self.raw.pair as *mut (K, V), (k, v));

(old_hash, old_key, old_val)
}
@@ -479,7 +477,8 @@
{
/// Gets mutable references to the key and value at a given index.
pub fn read_mut(&mut self) -> (&mut K, &mut V) {
- unsafe { (&mut *(self.raw.key as *mut K), &mut *(self.raw.val as *mut V)) }
+ let pair_mut = self.raw.pair as *mut (K, V);
+ unsafe { (&mut (*pair_mut).0, &mut (*pair_mut).1) }
}
}

@@ -492,7 +491,7 @@ impl<'t, K, V, M> FullBucket<K, V, M>
/// in exchange for this, the returned references have a longer lifetime
/// than the references returned by `read()`.
pub fn into_refs(self) -> (&'t K, &'t V) {
- unsafe { (&*self.raw.key, &*self.raw.val) }
+ unsafe { (&(*self.raw.pair).0, &(*self.raw.pair).1) }
}
}

@@ -502,7 +501,8 @@
/// This works similarly to `into_refs`, exchanging a bucket state
/// for mutable references into the table.
pub fn into_mut_refs(self) -> (&'t mut K, &'t mut V) {
- unsafe { (&mut *(self.raw.key as *mut K), &mut *(self.raw.val as *mut V)) }
+ let pair_mut = self.raw.pair as *mut (K, V);
+ unsafe { (&mut (*pair_mut).0, &mut (*pair_mut).1) }
}
}

@@ -517,8 +517,7 @@ impl<K, V, M> GapThenFull<K, V, M>
pub fn shift(mut self) -> Option<GapThenFull<K, V, M>> {
unsafe {
*self.gap.raw.hash = mem::replace(&mut *self.full.raw.hash, EMPTY_BUCKET);
- ptr::copy_nonoverlapping(self.full.raw.key, self.gap.raw.key as *mut K, 1);
- ptr::copy_nonoverlapping(self.full.raw.val, self.gap.raw.val as *mut V, 1);
+ ptr::copy_nonoverlapping(self.full.raw.pair, self.gap.raw.pair as *mut (K, V), 1);
}

let FullBucket { raw: prev_raw, idx: prev_idx, .. } = self.full;
@@ -560,49 +559,42 @@ fn test_rounding() {
assert_eq!(round_up_to_next(5, 4), 8);
}

- // Returns a tuple of (key_offset, val_offset),
+ // Returns a tuple of (pairs_offset, end_of_pairs_offset),
// from the start of a mallocated array.
#[inline]
fn calculate_offsets(hashes_size: usize,
- keys_size: usize,
- keys_align: usize,
- vals_align: usize)
+ pairs_size: usize,
+ pairs_align: usize)
-> (usize, usize, bool) {
- let keys_offset = round_up_to_next(hashes_size, keys_align);
- let (end_of_keys, oflo) = keys_offset.overflowing_add(keys_size);
-
- let vals_offset = round_up_to_next(end_of_keys, vals_align);
+ let pairs_offset = round_up_to_next(hashes_size, pairs_align);
+ let (end_of_pairs, oflo) = pairs_offset.overflowing_add(pairs_size);

- (keys_offset, vals_offset, oflo)
+ (pairs_offset, end_of_pairs, oflo)
}

// Returns a tuple of (minimum required malloc alignment, hash_offset,
// array_size), from the start of a mallocated array.
fn calculate_allocation(hash_size: usize,
hash_align: usize,
- keys_size: usize,
- keys_align: usize,
- vals_size: usize,
- vals_align: usize)
+ pairs_size: usize,
+ pairs_align: usize)
-> (usize, usize, usize, bool) {
let hash_offset = 0;
- let (_, vals_offset, oflo) = calculate_offsets(hash_size, keys_size, keys_align, vals_align);
- let (end_of_vals, oflo2) = vals_offset.overflowing_add(vals_size);
+ let (_, end_of_pairs, oflo) = calculate_offsets(hash_size, pairs_size, pairs_align);

- let align = cmp::max(hash_align, cmp::max(keys_align, vals_align));
+ let align = cmp::max(hash_align, pairs_align);

- (align, hash_offset, end_of_vals, oflo || oflo2)
+ (align, hash_offset, end_of_pairs, oflo)
}

#[test]
fn test_offset_calculation() {
- assert_eq!(calculate_allocation(128, 8, 15, 1, 4, 4),
- (8, 0, 148, false));
- assert_eq!(calculate_allocation(3, 1, 2, 1, 1, 1), (1, 0, 6, false));
- assert_eq!(calculate_allocation(6, 2, 12, 4, 24, 8), (8, 0, 48, false));
- assert_eq!(calculate_offsets(128, 15, 1, 4), (128, 144, false));
- assert_eq!(calculate_offsets(3, 2, 1, 1), (3, 5, false));
- assert_eq!(calculate_offsets(6, 12, 4, 8), (8, 24, false));
+ assert_eq!(calculate_allocation(128, 8, 16, 8), (8, 0, 144, false));
+ assert_eq!(calculate_allocation(3, 1, 2, 1), (1, 0, 5, false));
+ assert_eq!(calculate_allocation(6, 2, 12, 4), (4, 0, 20, false));
+ assert_eq!(calculate_offsets(128, 15, 4), (128, 143, false));
+ assert_eq!(calculate_offsets(3, 2, 4), (4, 6, false));
+ assert_eq!(calculate_offsets(6, 12, 4), (8, 20, false));
}
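A standalone re-derivation of a couple of the updated expected values (an illustrative sketch that re-implements the rounding helper locally; the real `round_up_to_next` lives elsewhere in this file):

```rust
fn round_up_to_next(unrounded: usize, target_alignment: usize) -> usize {
    assert!(target_alignment.is_power_of_two());
    (unrounded + target_alignment - 1) & !(target_alignment - 1)
}

fn main() {
    // calculate_offsets(128, 15, 4): the pair array starts at
    // round_up_to_next(128, 4) = 128 and ends at 128 + 15 = 143,
    // hence (128, 143, false).
    assert_eq!(round_up_to_next(128, 4), 128);

    // calculate_allocation(6, 2, 12, 4): pairs start at round_up_to_next(6, 4)
    // = 8 and end at 8 + 12 = 20; the malloc alignment is max(2, 4) = 4,
    // hence (4, 0, 20, false).
    assert_eq!(round_up_to_next(6, 4), 8);
    assert_eq!(round_up_to_next(6, 4) + 12, 20);
}
```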

impl<K, V> RawTable<K, V> {
@@ -620,39 +612,31 @@ impl<K, V> RawTable<K, V> {

// No need for `checked_mul` before a more restrictive check performed
// later in this method.
- let hashes_size = capacity * size_of::<u64>();
- let keys_size = capacity * size_of::<K>();
- let vals_size = capacity * size_of::<V>();
+ let hashes_size = capacity.wrapping_mul(size_of::<u64>());
+ let pairs_size = capacity.wrapping_mul(size_of::<(K, V)>());

- // Allocating hashmaps is a little tricky. We need to allocate three
+ // Allocating hashmaps is a little tricky. We need to allocate two
// arrays, but since we know their sizes and alignments up front,
// we just allocate a single array, and then have the subarrays
// point into it.
//
// This is great in theory, but in practice getting the alignment
// right is a little subtle. Therefore, calculating offsets has been
// factored out into a different function.
- let (malloc_alignment, hash_offset, size, oflo) = calculate_allocation(hashes_size,
- align_of::<u64>(),
- keys_size,
- align_of::<K>(),
- vals_size,
- align_of::<V>());
-
+ let (alignment, hash_offset, size, oflo) = calculate_allocation(hashes_size,
+ align_of::<u64>(),
+ pairs_size,
+ align_of::<(K, V)>());
assert!(!oflo, "capacity overflow");

// One check for overflow that covers calculation and rounding of size.
- let size_of_bucket = size_of::<u64>()
- .checked_add(size_of::<K>())
- .unwrap()
- .checked_add(size_of::<V>())
- .unwrap();
+ let size_of_bucket = size_of::<u64>().checked_add(size_of::<(K, V)>()).unwrap();
assert!(size >=
capacity.checked_mul(size_of_bucket)
.expect("capacity overflow"),
"capacity overflow");

- let buffer = allocate(size, malloc_alignment);
+ let buffer = allocate(size, alignment);
if buffer.is_null() {
::alloc::oom()
}
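The "allocate a single array and have the subarrays point into it" idea from the comment above, sketched against the stable `std::alloc` API (the standard library itself goes through its internal allocator, and this sketch omits the overflow checks and deallocation that the real code handles):

```rust
use std::alloc::{alloc, handle_alloc_error, Layout};
use std::mem::{align_of, size_of};

fn round_up_to_next(unrounded: usize, target_alignment: usize) -> usize {
    (unrounded + target_alignment - 1) & !(target_alignment - 1)
}

// One buffer holding `capacity` hashes followed by `capacity` (K, V) pairs.
fn alloc_buckets<K, V>(capacity: usize) -> (*mut u64, *mut (K, V)) {
    assert!(capacity > 0, "sketch does not handle the empty-table case");
    let hashes_size = capacity * size_of::<u64>();
    let pairs_size = capacity * size_of::<(K, V)>();
    let pairs_offset = round_up_to_next(hashes_size, align_of::<(K, V)>());
    let align = align_of::<u64>().max(align_of::<(K, V)>());
    let layout = Layout::from_size_align(pairs_offset + pairs_size, align).unwrap();

    let buffer = unsafe { alloc(layout) };
    if buffer.is_null() {
        handle_alloc_error(layout);
    }
    // Both "subarrays" live in the same allocation, at fixed offsets.
    (buffer as *mut u64, unsafe { buffer.add(pairs_offset) } as *mut (K, V))
}

fn main() {
    // Leaks the buffer; the real table remembers the sizes and frees it in Drop.
    let (hashes, pairs) = alloc_buckets::<u64, u8>(8);
    println!("hashes at {:p}, pairs at {:p}", hashes, pairs);
}
```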
@@ -669,17 +653,16 @@ impl<K, V> RawTable<K, V> {

fn first_bucket_raw(&self) -> RawBucket<K, V> {
let hashes_size = self.capacity * size_of::<u64>();
- let keys_size = self.capacity * size_of::<K>();
+ let pairs_size = self.capacity * size_of::<(K, V)>();

- let buffer = *self.hashes as *const u8;
- let (keys_offset, vals_offset, oflo) =
- calculate_offsets(hashes_size, keys_size, align_of::<K>(), align_of::<V>());
+ let buffer = *self.hashes as *mut u8;
+ let (pairs_offset, _, oflo) =
+ calculate_offsets(hashes_size, pairs_size, align_of::<(K, V)>());
debug_assert!(!oflo, "capacity overflow");
unsafe {
RawBucket {
hash: *self.hashes,
- key: buffer.offset(keys_offset as isize) as *const K,
- val: buffer.offset(vals_offset as isize) as *const V,
+ pair: buffer.offset(pairs_offset as isize) as *const _,
_marker: marker::PhantomData,
}
}
@@ -844,7 +827,7 @@ impl<'a, K, V> Iterator for RevMoveBuckets<'a, K, V> {

if *self.raw.hash != EMPTY_BUCKET {
self.elems_left -= 1;
- return Some((ptr::read(self.raw.key), ptr::read(self.raw.val)));
+ return Some(ptr::read(self.raw.pair));
}
}
}
@@ -909,7 +892,7 @@ impl<'a, K, V> Iterator for Iter<'a, K, V> {
fn next(&mut self) -> Option<(&'a K, &'a V)> {
self.iter.next().map(|bucket| {
self.elems_left -= 1;
- unsafe { (&*bucket.key, &*bucket.val) }
+ unsafe { (&(*bucket.pair).0, &(*bucket.pair).1) }
})
}

@@ -929,7 +912,8 @@ impl<'a, K, V> Iterator for IterMut<'a, K, V> {
fn next(&mut self) -> Option<(&'a K, &'a mut V)> {
self.iter.next().map(|bucket| {
self.elems_left -= 1;
- unsafe { (&*bucket.key, &mut *(bucket.val as *mut V)) }
+ let pair_mut = bucket.pair as *mut (K, V);
+ unsafe { (&(*pair_mut).0, &mut (*pair_mut).1) }
})
}

@@ -950,7 +934,8 @@ impl<K, V> Iterator for IntoIter<K, V> {
self.iter.next().map(|bucket| {
self.table.size -= 1;
unsafe {
- (SafeHash { hash: *bucket.hash }, ptr::read(bucket.key), ptr::read(bucket.val))
+ let (k, v) = ptr::read(bucket.pair);
+ (SafeHash { hash: *bucket.hash }, k, v)
}
})
}
@@ -974,9 +959,8 @@ impl<'a, K, V> Iterator for Drain<'a, K, V> {
self.iter.next().map(|bucket| {
unsafe {
(**self.table).size -= 1;
- (SafeHash { hash: ptr::replace(bucket.hash, EMPTY_BUCKET) },
- ptr::read(bucket.key),
- ptr::read(bucket.val))
+ let (k, v) = ptr::read(bucket.pair);
+ (SafeHash { hash: ptr::replace(bucket.hash, EMPTY_BUCKET) }, k, v)
}
})
}
@@ -1015,8 +999,7 @@ impl<K: Clone, V: Clone> Clone for RawTable<K, V> {
(full.hash(), k.clone(), v.clone())
};
*new_buckets.raw.hash = h.inspect();
- ptr::write(new_buckets.raw.key as *mut K, k);
- ptr::write(new_buckets.raw.val as *mut V, v);
+ ptr::write(new_buckets.raw.pair as *mut (K, V), (k, v));
}
Empty(..) => {
*new_buckets.raw.hash = EMPTY_BUCKET;
@@ -1054,14 +1037,11 @@ impl<K, V> Drop for RawTable<K, V> {
}

let hashes_size = self.capacity * size_of::<u64>();
- let keys_size = self.capacity * size_of::<K>();
- let vals_size = self.capacity * size_of::<V>();
+ let pairs_size = self.capacity * size_of::<(K, V)>();
let (align, _, size, oflo) = calculate_allocation(hashes_size,
align_of::<u64>(),
- keys_size,
- align_of::<K>(),
- vals_size,
- align_of::<V>());
+ pairs_size,
+ align_of::<(K, V)>());

debug_assert!(!oflo, "should be impossible");
