diff --git a/src/libstd/bitv.rs b/src/libstd/bitv.rs index 91e9a1dc94062..901430a279b4f 100644 --- a/src/libstd/bitv.rs +++ b/src/libstd/bitv.rs @@ -19,37 +19,42 @@ export to_str; export eq_vec; export methods; +/// a mask that has a 1 for each defined bit in a small_bitv, assuming n bits +#[inline(always)] +fn small_mask(nbits: uint) -> u32 { + (1 << nbits) - 1 +} + struct small_bitv { + /// only the lowest nbits of this value are used. the rest is undefined. let mut bits: u32; new(bits: u32) { self.bits = bits; } priv { #[inline(always)] - fn bits_op(right_bits: u32, f: fn(u32, u32) -> u32) -> bool { + fn bits_op(right_bits: u32, nbits: uint, f: fn(u32, u32) -> u32) + -> bool { + let mask = small_mask(nbits); let old_b: u32 = self.bits; let new_b = f(old_b, right_bits); self.bits = new_b; - old_b != new_b + mask & old_b != mask & new_b } } #[inline(always)] - fn union(s: &small_bitv) -> bool { - self.bits_op(s.bits, |u1, u2| { u1 | u2 }) + fn union(s: &small_bitv, nbits: uint) -> bool { + self.bits_op(s.bits, nbits, |u1, u2| u1 | u2) } #[inline(always)] - fn intersect(s: &small_bitv) -> bool { - self.bits_op(s.bits, |u1, u2| { u1 & u2 }) + fn intersect(s: &small_bitv, nbits: uint) -> bool { + self.bits_op(s.bits, nbits, |u1, u2| u1 & u2) } #[inline(always)] - fn become(s: &small_bitv) -> bool { - let old = self.bits; - self.bits = s.bits; - old != self.bits + fn become(s: &small_bitv, nbits: uint) -> bool { + self.bits_op(s.bits, nbits, |_u1, u2| u2) } #[inline(always)] - fn difference(s: &small_bitv) -> bool { - let old = self.bits; - self.bits &= !s.bits; - old != self.bits + fn difference(s: &small_bitv, nbits: uint) -> bool { + self.bits_op(s.bits, nbits, |u1, u2| u1 ^ u2) } #[inline(always)] pure fn get(i: uint) -> bool { @@ -65,38 +70,66 @@ struct small_bitv { } } #[inline(always)] - fn equals(b: &small_bitv) -> bool { self.bits == b.bits } + fn equals(b: &small_bitv, nbits: uint) -> bool { + let mask = small_mask(nbits); + mask & self.bits == mask & b.bits + } #[inline(always)] fn clear() { self.bits = 0; } #[inline(always)] fn set_all() { self.bits = !0; } #[inline(always)] - fn is_true() -> bool { self.bits == !0 } + fn is_true(nbits: uint) -> bool { + small_mask(nbits) & !self.bits == 0 + } #[inline(always)] - fn is_false() -> bool { self.bits == 0 } + fn is_false(nbits: uint) -> bool { + small_mask(nbits) & self.bits == 0 + } #[inline(always)] fn invert() { self.bits = !self.bits; } } +/** + * a mask that has a 1 for each defined bit in the nth element of a big_bitv, + * assuming n bits. + */ +#[inline(always)] +fn big_mask(nbits: uint, elem: uint) -> uint { + let rmd = nbits % uint_bits; + let nelems = nbits/uint_bits + if rmd == 0 {0} else {1}; + + if elem < nelems - 1 || rmd == 0 { + !0 + } else { + (1 << rmd) - 1 + } +} + struct big_bitv { -// only mut b/c of clone and lack of other constructor + // only mut b/c of clone and lack of other constructor let mut storage: ~[mut uint]; new(-storage: ~[mut uint]) { self.storage <- storage; } priv { #[inline(always)] - fn process(b: &big_bitv, op: fn(uint, uint) -> uint) -> bool { + fn process(b: &big_bitv, nbits: uint, op: fn(uint, uint) -> uint) + -> bool { let len = b.storage.len(); assert (self.storage.len() == len); let mut changed = false; do uint::range(0, len) |i| { - let w0 = self.storage[i]; - let w1 = b.storage[i]; - let w = op(w0, w1); - if w0 != w unchecked { changed = true; self.storage[i] = w; }; + let mask = big_mask(nbits, i); + let w0 = self.storage[i] & mask; + let w1 = b.storage[i] & mask; + let w = op(w0, w1) & mask; + if w0 != w unchecked { + changed = true; + self.storage[i] = w; + } true - }; + } changed } } @@ -112,15 +145,21 @@ struct big_bitv { #[inline(always)] fn invert() { for self.each_storage() |w| { w = !w } } #[inline(always)] - fn union(b: &big_bitv) -> bool { self.process(b, lor) } + fn union(b: &big_bitv, nbits: uint) -> bool { + self.process(b, nbits, lor) + } #[inline(always)] - fn intersect(b: &big_bitv) -> bool { self.process(b, land) } + fn intersect(b: &big_bitv, nbits: uint) -> bool { + self.process(b, nbits, land) + } #[inline(always)] - fn become(b: &big_bitv) -> bool { self.process(b, right) } + fn become(b: &big_bitv, nbits: uint) -> bool { + self.process(b, nbits, right) + } #[inline(always)] - fn difference(b: &big_bitv) -> bool { + fn difference(b: &big_bitv, nbits: uint) -> bool { self.invert(); - let b = self.intersect(b); + let b = self.intersect(b, nbits); self.invert(); b } @@ -140,10 +179,13 @@ struct big_bitv { else { self.storage[w] & !flag }; } #[inline(always)] - fn equals(b: &big_bitv) -> bool { + fn equals(b: &big_bitv, nbits: uint) -> bool { let len = b.storage.len(); for uint::iterate(0, len) |i| { - if self.storage[i] != b.storage[i] { return false; } + let mask = big_mask(nbits, i); + if mask & self.storage[i] != mask & b.storage[i] { + return false; + } } } } @@ -163,8 +205,10 @@ struct bitv { self.rep = small(~small_bitv(if init {!0} else {0})); } else { - let s = to_mut(from_elem(nbits / uint_bits + 1, - if init {!0} else {0})); + let nelems = nbits/uint_bits + + if nbits % uint_bits == 0 {0} else {1}; + let elem = if init {!0} else {0}; + let s = to_mut(from_elem(nelems, elem)); self.rep = big(~big_bitv(s)); }; } @@ -182,20 +226,20 @@ struct bitv { match self.rep { small(s) => match other.rep { small(s1) => match op { - union => s.union(s1), - intersect => s.intersect(s1), - assign => s.become(s1), - difference => s.difference(s1) + union => s.union(s1, self.nbits), + intersect => s.intersect(s1, self.nbits), + assign => s.become(s1, self.nbits), + difference => s.difference(s1, self.nbits) }, big(s1) => self.die() }, big(s) => match other.rep { small(_) => self.die(), big(s1) => match op { - union => s.union(s1), - intersect => s.intersect(s1), - assign => s.become(s1), - difference => s.difference(s1) + union => s.union(s1, self.nbits), + intersect => s.intersect(s1, self.nbits), + assign => s.become(s1, self.nbits), + difference => s.difference(s1, self.nbits) } } } @@ -280,11 +324,11 @@ struct bitv { if self.nbits != v1.nbits { return false; } match self.rep { small(b) => match v1.rep { - small(b1) => b.equals(b1), + small(b1) => b.equals(b1, self.nbits), _ => false }, big(s) => match v1.rep { - big(s1) => s.equals(s1), + big(s1) => s.equals(s1, self.nbits), small(_) => return false } } @@ -330,7 +374,7 @@ struct bitv { #[inline(always)] fn is_true() -> bool { match self.rep { - small(b) => b.is_true(), + small(b) => b.is_true(self.nbits), _ => { for self.each() |i| { if !i { return false; } } true @@ -351,7 +395,7 @@ struct bitv { fn is_false() -> bool { match self.rep { - small(b) => b.is_false(), + small(b) => b.is_false(self.nbits), big(_) => { for self.each() |i| { if i { return false; } } true @@ -740,6 +784,33 @@ mod tests { let v1 = bitv(110u, false); assert !v0.equal(v1); } + + #[test] + fn test_equal_sneaky_small() { + let a = bitv::bitv(1, false); + a.set(0, true); + + let b = bitv::bitv(1, true); + b.set(0, true); + + assert a.equal(b); + } + + #[test] + fn test_equal_sneaky_big() { + let a = bitv::bitv(100, false); + for uint::range(0, 100) |i| { + a.set(i, true); + } + + let b = bitv::bitv(100, true); + for uint::range(0, 100) |i| { + b.set(i, true); + } + + assert a.equal(b); + } + } //