Skip to content

Commit

Permalink
Auto merge of rust-lang#32635 - gereeter:hashmap-iter-variance, r=ale…
Browse files Browse the repository at this point in the history
…xcrichton

Make HashMap, HashSet, and their iterators properly covariant

See rust-lang#30642. `Drain` is the only type left invariant.
  • Loading branch information
bors committed Apr 1, 2016
2 parents 3b342fa + 589108b commit 53498ec
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 44 deletions.
44 changes: 27 additions & 17 deletions src/libstd/collections/hash/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use borrow::Borrow;
use cmp::max;
use fmt::{self, Debug};
use hash::{Hash, SipHasher, BuildHasher};
use iter::{self, Map, FromIterator};
use iter::FromIterator;
use mem::{self, replace};
use ops::{Deref, Index};
use rand::{self, Rng};
Expand Down Expand Up @@ -836,8 +836,7 @@ impl<K, V, S> HashMap<K, V, S>
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn keys<'a>(&'a self) -> Keys<'a, K, V> {
fn first<A, B>((a, _): (A, B)) -> A { a }
Keys { inner: self.iter().map(first) }
Keys { inner: self.iter() }
}

/// An iterator visiting all values in arbitrary order.
Expand All @@ -859,8 +858,7 @@ impl<K, V, S> HashMap<K, V, S>
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn values<'a>(&'a self) -> Values<'a, K, V> {
fn second<A, B>((_, b): (A, B)) -> B { b }
Values { inner: self.iter().map(second) }
Values { inner: self.iter() }
}

/// An iterator visiting all key-value pairs in arbitrary order.
Expand Down Expand Up @@ -992,9 +990,8 @@ impl<K, V, S> HashMap<K, V, S>
#[inline]
#[stable(feature = "drain", since = "1.6.0")]
pub fn drain(&mut self) -> Drain<K, V> {
fn last_two<A, B, C>((_, b, c): (A, B, C)) -> (B, C) { (b, c) }
Drain {
inner: self.table.drain().map(last_two),
inner: self.table.drain(),
}
}

Expand Down Expand Up @@ -1224,13 +1221,13 @@ pub struct IterMut<'a, K: 'a, V: 'a> {
/// HashMap move iterator.
#[stable(feature = "rust1", since = "1.0.0")]
pub struct IntoIter<K, V> {
inner: iter::Map<table::IntoIter<K, V>, fn((SafeHash, K, V)) -> (K, V)>
inner: table::IntoIter<K, V>
}

/// HashMap keys iterator.
#[stable(feature = "rust1", since = "1.0.0")]
pub struct Keys<'a, K: 'a, V: 'a> {
inner: Map<Iter<'a, K, V>, fn((&'a K, &'a V)) -> &'a K>
inner: Iter<'a, K, V>
}

// FIXME(#19839) Remove in favor of `#[derive(Clone)]`
Expand All @@ -1246,7 +1243,7 @@ impl<'a, K, V> Clone for Keys<'a, K, V> {
/// HashMap values iterator.
#[stable(feature = "rust1", since = "1.0.0")]
pub struct Values<'a, K: 'a, V: 'a> {
inner: Map<Iter<'a, K, V>, fn((&'a K, &'a V)) -> &'a V>
inner: Iter<'a, K, V>
}

// FIXME(#19839) Remove in favor of `#[derive(Clone)]`
Expand All @@ -1262,7 +1259,7 @@ impl<'a, K, V> Clone for Values<'a, K, V> {
/// HashMap drain iterator.
#[stable(feature = "drain", since = "1.6.0")]
pub struct Drain<'a, K: 'a, V: 'a> {
inner: iter::Map<table::Drain<'a, K, V>, fn((SafeHash, K, V)) -> (K, V)>
inner: table::Drain<'a, K, V>
}

enum InternalEntry<K, V, M> {
Expand Down Expand Up @@ -1397,9 +1394,8 @@ impl<K, V, S> IntoIterator for HashMap<K, V, S>
/// let vec: Vec<(&str, isize)> = map.into_iter().collect();
/// ```
fn into_iter(self) -> IntoIter<K, V> {
fn last_two<A, B, C>((_, b, c): (A, B, C)) -> (B, C) { (b, c) }
IntoIter {
inner: self.table.into_iter().map(last_two)
inner: self.table.into_iter()
}
}
}
Expand Down Expand Up @@ -1432,7 +1428,7 @@ impl<'a, K, V> ExactSizeIterator for IterMut<'a, K, V> {
impl<K, V> Iterator for IntoIter<K, V> {
type Item = (K, V);

#[inline] fn next(&mut self) -> Option<(K, V)> { self.inner.next() }
#[inline] fn next(&mut self) -> Option<(K, V)> { self.inner.next().map(|(_, k, v)| (k, v)) }
#[inline] fn size_hint(&self) -> (usize, Option<usize>) { self.inner.size_hint() }
}
#[stable(feature = "rust1", since = "1.0.0")]
Expand All @@ -1444,7 +1440,7 @@ impl<K, V> ExactSizeIterator for IntoIter<K, V> {
impl<'a, K, V> Iterator for Keys<'a, K, V> {
type Item = &'a K;

#[inline] fn next(&mut self) -> Option<(&'a K)> { self.inner.next() }
#[inline] fn next(&mut self) -> Option<(&'a K)> { self.inner.next().map(|(k, _)| k) }
#[inline] fn size_hint(&self) -> (usize, Option<usize>) { self.inner.size_hint() }
}
#[stable(feature = "rust1", since = "1.0.0")]
Expand All @@ -1456,7 +1452,7 @@ impl<'a, K, V> ExactSizeIterator for Keys<'a, K, V> {
impl<'a, K, V> Iterator for Values<'a, K, V> {
type Item = &'a V;

#[inline] fn next(&mut self) -> Option<(&'a V)> { self.inner.next() }
#[inline] fn next(&mut self) -> Option<(&'a V)> { self.inner.next().map(|(_, v)| v) }
#[inline] fn size_hint(&self) -> (usize, Option<usize>) { self.inner.size_hint() }
}
#[stable(feature = "rust1", since = "1.0.0")]
Expand All @@ -1468,7 +1464,7 @@ impl<'a, K, V> ExactSizeIterator for Values<'a, K, V> {
impl<'a, K, V> Iterator for Drain<'a, K, V> {
type Item = (K, V);

#[inline] fn next(&mut self) -> Option<(K, V)> { self.inner.next() }
#[inline] fn next(&mut self) -> Option<(K, V)> { self.inner.next().map(|(_, k, v)| (k, v)) }
#[inline] fn size_hint(&self) -> (usize, Option<usize>) { self.inner.size_hint() }
}
#[stable(feature = "rust1", since = "1.0.0")]
Expand Down Expand Up @@ -1674,6 +1670,20 @@ impl<K, S, Q: ?Sized> super::Recover<Q> for HashMap<K, (), S>
}
}

#[allow(dead_code)]
fn assert_covariance() {
fn map_key<'new>(v: HashMap<&'static str, u8>) -> HashMap<&'new str, u8> { v }
fn map_val<'new>(v: HashMap<u8, &'static str>) -> HashMap<u8, &'new str> { v }
fn iter_key<'a, 'new>(v: Iter<'a, &'static str, u8>) -> Iter<'a, &'new str, u8> { v }
fn iter_val<'a, 'new>(v: Iter<'a, u8, &'static str>) -> Iter<'a, u8, &'new str> { v }
fn into_iter_key<'new>(v: IntoIter<&'static str, u8>) -> IntoIter<&'new str, u8> { v }
fn into_iter_val<'new>(v: IntoIter<u8, &'static str>) -> IntoIter<u8, &'new str> { v }
fn keys_key<'a, 'new>(v: Keys<'a, &'static str, u8>) -> Keys<'a, &'new str, u8> { v }
fn keys_val<'a, 'new>(v: Keys<'a, u8, &'static str>) -> Keys<'a, u8, &'new str> { v }
fn values_key<'a, 'new>(v: Values<'a, &'static str, u8>) -> Values<'a, &'new str, u8> { v }
fn values_val<'a, 'new>(v: Values<'a, u8, &'static str>) -> Values<'a, u8, &'new str> { v }
}

#[cfg(test)]
mod test_map {
use prelude::v1::*;
Expand Down
31 changes: 22 additions & 9 deletions src/libstd/collections/hash/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
use borrow::Borrow;
use fmt;
use hash::{Hash, BuildHasher};
use iter::{Map, Chain, FromIterator};
use iter::{Chain, FromIterator};
use ops::{BitOr, BitAnd, BitXor, Sub};

use super::Recover;
Expand Down Expand Up @@ -414,8 +414,7 @@ impl<T, S> HashSet<T, S>
#[inline]
#[stable(feature = "drain", since = "1.6.0")]
pub fn drain(&mut self) -> Drain<T> {
fn first<A, B>((a, _): (A, B)) -> A { a }
Drain { iter: self.map.drain().map(first) }
Drain { iter: self.map.drain() }
}

/// Clears the set, removing all values.
Expand Down Expand Up @@ -809,13 +808,13 @@ pub struct Iter<'a, K: 'a> {
/// HashSet move iterator
#[stable(feature = "rust1", since = "1.0.0")]
pub struct IntoIter<K> {
iter: Map<map::IntoIter<K, ()>, fn((K, ())) -> K>
iter: map::IntoIter<K, ()>
}

/// HashSet drain iterator
#[stable(feature = "rust1", since = "1.0.0")]
pub struct Drain<'a, K: 'a> {
iter: Map<map::Drain<'a, K, ()>, fn((K, ())) -> K>,
iter: map::Drain<'a, K, ()>,
}

/// Intersection iterator
Expand Down Expand Up @@ -889,8 +888,7 @@ impl<T, S> IntoIterator for HashSet<T, S>
/// }
/// ```
fn into_iter(self) -> IntoIter<T> {
fn first<A, B>((a, _): (A, B)) -> A { a }
IntoIter { iter: self.map.into_iter().map(first) }
IntoIter { iter: self.map.into_iter() }
}
}

Expand All @@ -914,7 +912,7 @@ impl<'a, K> ExactSizeIterator for Iter<'a, K> {
impl<K> Iterator for IntoIter<K> {
type Item = K;

fn next(&mut self) -> Option<K> { self.iter.next() }
fn next(&mut self) -> Option<K> { self.iter.next().map(|(k, _)| k) }
fn size_hint(&self) -> (usize, Option<usize>) { self.iter.size_hint() }
}
#[stable(feature = "rust1", since = "1.0.0")]
Expand All @@ -926,7 +924,7 @@ impl<K> ExactSizeIterator for IntoIter<K> {
impl<'a, K> Iterator for Drain<'a, K> {
type Item = K;

fn next(&mut self) -> Option<K> { self.iter.next() }
fn next(&mut self) -> Option<K> { self.iter.next().map(|(k, _)| k) }
fn size_hint(&self) -> (usize, Option<usize>) { self.iter.size_hint() }
}
#[stable(feature = "rust1", since = "1.0.0")]
Expand Down Expand Up @@ -1026,6 +1024,21 @@ impl<'a, T, S> Iterator for Union<'a, T, S>
fn size_hint(&self) -> (usize, Option<usize>) { self.iter.size_hint() }
}

#[allow(dead_code)]
fn assert_covariance() {
fn set<'new>(v: HashSet<&'static str>) -> HashSet<&'new str> { v }
fn iter<'a, 'new>(v: Iter<'a, &'static str>) -> Iter<'a, &'new str> { v }
fn into_iter<'new>(v: IntoIter<&'static str>) -> IntoIter<&'new str> { v }
fn difference<'a, 'new>(v: Difference<'a, &'static str, RandomState>)
-> Difference<'a, &'new str, RandomState> { v }
fn symmetric_difference<'a, 'new>(v: SymmetricDifference<'a, &'static str, RandomState>)
-> SymmetricDifference<'a, &'new str, RandomState> { v }
fn intersection<'a, 'new>(v: Intersection<'a, &'static str, RandomState>)
-> Intersection<'a, &'new str, RandomState> { v }
fn union<'a, 'new>(v: Union<'a, &'static str, RandomState>)
-> Union<'a, &'new str, RandomState> { v }
}

#[cfg(test)]
mod test_set {
use prelude::v1::*;
Expand Down
41 changes: 23 additions & 18 deletions src/libstd/collections/hash/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ unsafe impl<K: Sync, V: Sync> Sync for RawTable<K, V> {}

struct RawBucket<K, V> {
hash: *mut u64,
key: *mut K,
val: *mut V,

// We use *const to ensure covariance with respect to K and V
key: *const K,
val: *const V,
_marker: marker::PhantomData<(K,V)>,
}

Expand Down Expand Up @@ -354,8 +356,8 @@ impl<K, V, M> EmptyBucket<K, V, M> where M: Put<K, V> {
-> FullBucket<K, V, M> {
unsafe {
*self.raw.hash = hash.inspect();
ptr::write(self.raw.key, key);
ptr::write(self.raw.val, value);
ptr::write(self.raw.key as *mut K, key);
ptr::write(self.raw.val as *mut V, value);

self.table.borrow_table_mut().size += 1;
}
Expand Down Expand Up @@ -453,8 +455,8 @@ impl<K, V, M> FullBucket<K, V, M> where M: Put<K, V> {
pub fn replace(&mut self, h: SafeHash, k: K, v: V) -> (SafeHash, K, V) {
unsafe {
let old_hash = ptr::replace(self.raw.hash as *mut SafeHash, h);
let old_key = ptr::replace(self.raw.key, k);
let old_val = ptr::replace(self.raw.val, v);
let old_key = ptr::replace(self.raw.key as *mut K, k);
let old_val = ptr::replace(self.raw.val as *mut V, v);

(old_hash, old_key, old_val)
}
Expand All @@ -465,8 +467,8 @@ impl<K, V, M> FullBucket<K, V, M> where M: Deref<Target=RawTable<K, V>> + DerefM
/// Gets mutable references to the key and value at a given index.
pub fn read_mut(&mut self) -> (&mut K, &mut V) {
unsafe {
(&mut *self.raw.key,
&mut *self.raw.val)
(&mut *(self.raw.key as *mut K),
&mut *(self.raw.val as *mut V))
}
}
}
Expand All @@ -490,8 +492,8 @@ impl<'t, K, V, M> FullBucket<K, V, M> where M: Deref<Target=RawTable<K, V>> + De
/// for mutable references into the table.
pub fn into_mut_refs(self) -> (&'t mut K, &'t mut V) {
unsafe {
(&mut *self.raw.key,
&mut *self.raw.val)
(&mut *(self.raw.key as *mut K),
&mut *(self.raw.val as *mut V))
}
}
}
Expand All @@ -505,8 +507,8 @@ impl<K, V, M> GapThenFull<K, V, M> where M: Deref<Target=RawTable<K, V>> {
pub fn shift(mut self) -> Option<GapThenFull<K, V, M>> {
unsafe {
*self.gap.raw.hash = mem::replace(&mut *self.full.raw.hash, EMPTY_BUCKET);
ptr::copy_nonoverlapping(self.full.raw.key, self.gap.raw.key, 1);
ptr::copy_nonoverlapping(self.full.raw.val, self.gap.raw.val, 1);
ptr::copy_nonoverlapping(self.full.raw.key, self.gap.raw.key as *mut K, 1);
ptr::copy_nonoverlapping(self.full.raw.val, self.gap.raw.val as *mut V, 1);
}

let FullBucket { raw: prev_raw, idx: prev_idx, .. } = self.full;
Expand Down Expand Up @@ -649,7 +651,7 @@ impl<K, V> RawTable<K, V> {
let hashes_size = self.capacity * size_of::<u64>();
let keys_size = self.capacity * size_of::<K>();

let buffer = *self.hashes as *mut u8;
let buffer = *self.hashes as *const u8;
let (keys_offset, vals_offset, oflo) =
calculate_offsets(hashes_size,
keys_size, align_of::<K>(),
Expand All @@ -658,8 +660,8 @@ impl<K, V> RawTable<K, V> {
unsafe {
RawBucket {
hash: *self.hashes,
key: buffer.offset(keys_offset as isize) as *mut K,
val: buffer.offset(vals_offset as isize) as *mut V,
key: buffer.offset(keys_offset as isize) as *const K,
val: buffer.offset(vals_offset as isize) as *const V,
_marker: marker::PhantomData,
}
}
Expand Down Expand Up @@ -707,6 +709,7 @@ impl<K, V> RawTable<K, V> {
IterMut {
iter: self.raw_buckets(),
elems_left: self.size(),
_marker: marker::PhantomData,
}
}

Expand Down Expand Up @@ -858,6 +861,8 @@ impl<'a, K, V> Clone for Iter<'a, K, V> {
pub struct IterMut<'a, K: 'a, V: 'a> {
iter: RawBuckets<'a, K, V>,
elems_left: usize,
// To ensure invariance with respect to V
_marker: marker::PhantomData<&'a mut V>,
}

unsafe impl<'a, K: Sync, V: Sync> Sync for IterMut<'a, K, V> {}
Expand Down Expand Up @@ -912,7 +917,7 @@ impl<'a, K, V> Iterator for IterMut<'a, K, V> {
self.elems_left -= 1;
unsafe {
(&*bucket.key,
&mut *bucket.val)
&mut *(bucket.val as *mut V))
}
})
}
Expand Down Expand Up @@ -1003,8 +1008,8 @@ impl<K: Clone, V: Clone> Clone for RawTable<K, V> {
(full.hash(), k.clone(), v.clone())
};
*new_buckets.raw.hash = h.inspect();
ptr::write(new_buckets.raw.key, k);
ptr::write(new_buckets.raw.val, v);
ptr::write(new_buckets.raw.key as *mut K, k);
ptr::write(new_buckets.raw.val as *mut V, v);
}
Empty(..) => {
*new_buckets.raw.hash = EMPTY_BUCKET;
Expand Down

0 comments on commit 53498ec

Please sign in to comment.