diff --git a/src/combinations.rs b/src/combinations.rs index 18adfc70e..c50e95d01 100644 --- a/src/combinations.rs +++ b/src/combinations.rs @@ -120,9 +120,95 @@ impl Iterator for Combinations // Create result vector based on the indices Some(self.indices.iter().map(|i| self.pool[*i].clone()).collect()) } + + fn size_hint(&self) -> (usize, Option) { + let (mut low, mut upp) = self.pool.size_hint(); + low = remaining_for(low, self.first, &self.indices).unwrap_or(usize::MAX); + upp = upp.and_then(|upp| remaining_for(upp, self.first, &self.indices)); + (low, upp) + } + + fn count(self) -> usize { + let Self { indices, pool, first } = self; + // TODO: make `pool.it` private + let n = pool.len() + pool.it.count(); + remaining_for(n, first, &indices).unwrap() + } } impl FusedIterator for Combinations where I: Iterator, I::Item: Clone {} + +// https://en.wikipedia.org/wiki/Binomial_coefficient#In_programming_languages +pub(crate) fn checked_binomial(mut n: usize, mut k: usize) -> Option { + if n < k { + return Some(0); + } + // `factorial(n) / factorial(n - k) / factorial(k)` but trying to avoid it overflows: + k = (n - k).min(k); // symmetry + let mut c = 1; + for i in 1..=k { + c = (c / i).checked_mul(n)?.checked_add((c % i).checked_mul(n)? / i)?; + n -= 1; + } + Some(c) +} + +#[test] +fn test_checked_binomial() { + // With the first row: [1, 0, 0, ...] and the first column full of 1s, we check + // row by row the recurrence relation of binomials (which is an equivalent definition). + // For n >= 1 and k >= 1 we have: + // binomial(n, k) == binomial(n - 1, k - 1) + binomial(n - 1, k) + const LIMIT: usize = 500; + let mut row = vec![Some(0); LIMIT + 1]; + row[0] = Some(1); + for n in 0..=LIMIT { + for k in 0..=LIMIT { + assert_eq!(row[k], checked_binomial(n, k)); + } + row = std::iter::once(Some(1)) + .chain((1..=LIMIT).map(|k| row[k - 1]?.checked_add(row[k]?))) + .collect(); + } +} + +/// For a given size `n`, return the count of remaining combinations or None if it would overflow. +fn remaining_for(n: usize, first: bool, indices: &[usize]) -> Option { + let k = indices.len(); + if n < k { + Some(0) + } else if first { + checked_binomial(n, k) + } else { + // https://en.wikipedia.org/wiki/Combinatorial_number_system + // http://www.site.uottawa.ca/~lucia/courses/5165-09/GenCombObj.pdf + + // The combinations generated after the current one can be counted by counting as follows: + // - The subsequent combinations that differ in indices[0]: + // If subsequent combinations differ in indices[0], then their value for indices[0] + // must be at least 1 greater than the current indices[0]. + // As indices is strictly monotonically sorted, this means we can effectively choose k values + // from (n - 1 - indices[0]), leading to binomial(n - 1 - indices[0], k) possibilities. + // - The subsequent combinations with same indices[0], but differing indices[1]: + // Here we can choose k - 1 values from (n - 1 - indices[1]) values, + // leading to binomial(n - 1 - indices[1], k - 1) possibilities. + // - (...) + // - The subsequent combinations with same indices[0..=i], but differing indices[i]: + // Here we can choose k - i values from (n - 1 - indices[i]) values: binomial(n - 1 - indices[i], k - i). + // Since subsequent combinations can in any index, we must sum up the aforementioned binomial coefficients. + + // Below, `n0` resembles indices[i]. + indices + .iter() + .enumerate() + // TODO: Once the MSRV hits 1.37.0, we can sum options instead: + // .map(|(i, n0)| checked_binomial(n - 1 - *n0, k - i)) + // .sum() + .fold(Some(0), |sum, (i, n0)| { + sum.and_then(|s| s.checked_add(checked_binomial(n - 1 - *n0, k - i)?)) + }) + } +} diff --git a/src/lazy_buffer.rs b/src/lazy_buffer.rs index 65cc8e2cd..88ee06c7c 100644 --- a/src/lazy_buffer.rs +++ b/src/lazy_buffer.rs @@ -2,6 +2,8 @@ use std::iter::Fuse; use std::ops::Index; use alloc::vec::Vec; +use crate::size_hint::{self, SizeHint}; + #[derive(Debug, Clone)] pub struct LazyBuffer { pub it: Fuse, @@ -23,6 +25,10 @@ where self.buffer.len() } + pub fn size_hint(&self) -> SizeHint { + size_hint::add_scalar(self.it.size_hint(), self.len()) + } + pub fn get_next(&mut self) -> bool { if let Some(x) = self.it.next() { self.buffer.push(x); diff --git a/tests/test_std.rs b/tests/test_std.rs index 77207d87e..a351f4b3e 100644 --- a/tests/test_std.rs +++ b/tests/test_std.rs @@ -909,6 +909,28 @@ fn combinations_zero() { it::assert_equal((0..0).combinations(0), vec![vec![]]); } +#[test] +fn combinations_range_count() { + for n in 0..=10 { + for k in 0..=n { + let len = (n - k + 1..=n).product::() / (1..=k).product::(); + let mut it = (0..n).combinations(k); + assert_eq!(len, it.clone().count()); + assert_eq!(len, it.size_hint().0); + assert_eq!(Some(len), it.size_hint().1); + for count in (0..len).rev() { + let elem = it.next(); + assert!(elem.is_some()); + assert_eq!(count, it.clone().count()); + assert_eq!(count, it.size_hint().0); + assert_eq!(Some(count), it.size_hint().1); + } + let should_be_none = it.next(); + assert!(should_be_none.is_none()); + } + } +} + #[test] fn permutations_zero() { it::assert_equal((1..3).permutations(0), vec![vec![]]);