Simplify Permutations #790

Merged
257 changes: 97 additions & 160 deletions src/permutations.rs
@@ -1,6 +1,7 @@
use alloc::vec::Vec;
use std::fmt;
use std::iter::once;
use std::iter::FusedIterator;

use super::lazy_buffer::LazyBuffer;
use crate::size_hint::{self, SizeHint};
@@ -26,22 +27,17 @@ where

#[derive(Clone, Debug)]
enum PermutationState {
StartUnknownLen { k: usize },
OngoingUnknownLen { k: usize, min_n: usize },
Complete(CompleteState),
Empty,
}

#[derive(Clone, Debug)]
enum CompleteState {
Start {
n: usize,
k: usize,
},
Ongoing {
/// No permutation generated yet.
Start { k: usize },
/// Values from the iterator are not fully loaded yet so `n` is still unknown.
Buffered { k: usize, min_n: usize },
/// All values from the iterator are known so `n` is known.
Loaded {
indices: Vec<usize>,
cycles: Vec<usize>,
},
/// No permutation left to generate.
End,
}

impl<I> fmt::Debug for Permutations<I>
@@ -55,20 +51,13 @@ where
pub fn permutations<I: Iterator>(iter: I, k: usize) -> Permutations<I> {
let mut vals = LazyBuffer::new(iter);

if k == 0 {
// Special case, yields single empty vec; `n` is irrelevant
let state = PermutationState::Complete(CompleteState::Start { n: 0, k: 0 });

return Permutations { vals, state };
}

vals.prefill(k);
let enough_vals = vals.len() == k;

let state = if enough_vals {
PermutationState::StartUnknownLen { k }
PermutationState::Start { k }
} else {
PermutationState::Empty
PermutationState::End
};

Permutations { vals, state }
@@ -82,169 +71,117 @@ where
type Item = Vec<I::Item>;

fn next(&mut self) -> Option<Self::Item> {
{
let &mut Permutations {
ref mut vals,
ref mut state,
} = self;
match *state {
PermutationState::StartUnknownLen { k } => {
*state = PermutationState::OngoingUnknownLen { k, min_n: k };
}
PermutationState::OngoingUnknownLen { k, min_n } => {
if vals.get_next() {
*state = PermutationState::OngoingUnknownLen {
k,
min_n: min_n + 1,
};
} else {
let n = min_n;
let prev_iteration_count = n - k + 1;
let mut complete_state = CompleteState::Start { n, k };

// Advance the complete-state iterator to the correct point
for _ in 0..(prev_iteration_count + 1) {
complete_state.advance();
let Self { vals, state } = self;
match state {
PermutationState::Start { k: 0 } => {
*state = PermutationState::End;
Some(Vec::new())
}
&mut PermutationState::Start { k } => {
*state = PermutationState::Buffered { k, min_n: k };
Some(vals[0..k].to_vec())
}
PermutationState::Buffered { ref k, min_n } => {
if vals.get_next() {
let item = (0..*k - 1)
.chain(once(*min_n))
.map(|i| vals[i].clone())
.collect();
*min_n += 1;
Some(item)
} else {
let n = *min_n;
let prev_iteration_count = n - *k + 1;
let mut indices: Vec<_> = (0..n).collect();
let mut cycles: Vec<_> = (n - k..n).rev().collect();
// Advance the state to the correct point.
for _ in 0..prev_iteration_count {
if advance(&mut indices, &mut cycles) {
*state = PermutationState::End;
return None;
}

*state = PermutationState::Complete(complete_state);
}
let item = indices[0..*k].iter().map(|&i| vals[i].clone()).collect();
*state = PermutationState::Loaded { indices, cycles };
Some(item)
}
PermutationState::Complete(ref mut state) => {
state.advance();
}
PermutationState::Empty => {}
};
}
let &mut Permutations {
ref vals,
ref state,
} = self;
match *state {
PermutationState::StartUnknownLen { .. } => panic!("unexpected iterator state"),
PermutationState::OngoingUnknownLen { k, min_n } => {
let latest_idx = min_n - 1;
let indices = (0..(k - 1)).chain(once(latest_idx));

Some(indices.map(|i| vals[i].clone()).collect())
}
PermutationState::Complete(CompleteState::Ongoing {
ref indices,
ref cycles,
}) => {
PermutationState::Loaded { indices, cycles } => {
if advance(indices, cycles) {
*state = PermutationState::End;
return None;
}
let k = cycles.len();
Some(indices[0..k].iter().map(|&i| vals[i].clone()).collect())
}
PermutationState::Complete(CompleteState::Start { .. }) | PermutationState::Empty => {
None
}
PermutationState::End => None,
}
}

fn count(self) -> usize {
fn from_complete(complete_state: CompleteState) -> usize {
complete_state
.remaining()
.expect("Iterator count greater than usize::MAX")
}

let Permutations { vals, state } = self;
match state {
PermutationState::StartUnknownLen { k } => {
let n = vals.count();
let complete_state = CompleteState::Start { n, k };

from_complete(complete_state)
}
PermutationState::OngoingUnknownLen { k, min_n } => {
let prev_iteration_count = min_n - k + 1;
let n = vals.count();
let complete_state = CompleteState::Start { n, k };

from_complete(complete_state) - prev_iteration_count
}
PermutationState::Complete(state) => from_complete(state),
PermutationState::Empty => 0,
}
let Self { vals, state } = self;
let n = vals.count();
state.size_hint_for(n).1.unwrap()
Member

So, we go via SizeHint (i.e. (usize, Option<usize>)) to compute count. Fine given the simplification we gain by this.

Member Author

Indeed, size_hint_for mostly ends with (x.unwrap_or(usize::MAX), x), on which we then take either .0 or .1.
At worst, it needlessly unwraps an option once.
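
A minimal sketch of the pattern discussed here, written outside the diff for illustration only. The helper names `hint_from_checked` and `count_from_hint` are hypothetical and do not appear in the PR; only the `(x.unwrap_or(usize::MAX), x)` shape and the `.1.unwrap()` call in `count` are taken from it.

```rust
/// Same shape as the `SizeHint` alias referred to in the review.
type SizeHint = (usize, Option<usize>);

/// Hypothetical helper: turn an overflow-checked count into a size hint.
/// `None` means the true count does not fit in a `usize`.
fn hint_from_checked(x: Option<usize>) -> SizeHint {
    (x.unwrap_or(usize::MAX), x)
}

/// Hypothetical helper: recover an exact count from such a hint.
/// Panics if the count overflowed, as `.1.unwrap()` does in `count`.
fn count_from_hint(hint: SizeHint) -> usize {
    hint.1.unwrap()
}
```

With this shape, `size_hint` reads `.0` for the lower bound and `.1` for the upper bound, while `count` reads `.1` and unwraps it.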

}

fn size_hint(&self) -> SizeHint {
let at_start = |k| {
// At the beginning, there are `n!/(n-k)!` items to come (see `remaining`) but `n` might be unknown.
let (mut low, mut upp) = self.vals.size_hint();
low = CompleteState::Start { n: low, k }
.remaining()
.unwrap_or(usize::MAX);
upp = upp.and_then(|n| CompleteState::Start { n, k }.remaining());
(low, upp)
};
match self.state {
PermutationState::StartUnknownLen { k } => at_start(k),
PermutationState::OngoingUnknownLen { k, min_n } => {
// Same as `StartUnknownLen` minus the previously generated items.
size_hint::sub_scalar(at_start(k), min_n - k + 1)
}
PermutationState::Complete(ref state) => match state.remaining() {
Some(count) => (count, Some(count)),
None => (::std::usize::MAX, None),
},
PermutationState::Empty => (0, Some(0)),
}
let (mut low, mut upp) = self.vals.size_hint();
low = self.state.size_hint_for(low).0;
upp = upp.and_then(|n| self.state.size_hint_for(n).1);
Comment on lines +129 to +130
Member

This pattern seems familiar to me... Do you know if it occurs somewhere else? If so, should we introduce size_hint::map? (We can do this separately, it just occurred to me.)

Member Author

This pattern occurs most of the time, I think. I considered a similar method in a messy PR, but I had headaches over the conditions f must satisfy for the resulting size hint to be correct, so while writing several size hints I applied the pattern manually.
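
For illustration, a helper along the lines of the suggested size_hint::map might look like the sketch below. This is not an existing itertools function: `map_size_hint` and `ordered_pairs` are hypothetical names, and the requirement that `f` be non-decreasing in `n` is an assumption about one condition that would make the result correct, not something stated in the thread.

```rust
type SizeHint = (usize, Option<usize>);

/// Hypothetical `size_hint::map`-like helper; not an existing itertools function.
/// `f(n)` returns the hint an adaptor would have if its source held exactly `n` items.
/// Assumption: `f` is non-decreasing in `n`, otherwise the bounds may be wrong.
fn map_size_hint<F>((low, upp): SizeHint, f: F) -> SizeHint
where
    F: Fn(usize) -> SizeHint,
{
    // Lower bound from the source's lower bound, upper bound from its upper bound.
    (f(low).0, upp.and_then(|n| f(n).1))
}

/// Example `f`: number of ordered pairs of distinct items, `n * (n - 1)`,
/// with `None` when the product overflows.
fn ordered_pairs(n: usize) -> SizeHint {
    let count = n.checked_mul(n.saturating_sub(1));
    (count.unwrap_or(usize::MAX), count)
}
```

The monotonicity assumption matters because the true source length lies somewhere between `low` and `upp`; if `f` could decrease, `f(low).0` would no longer be a valid lower bound.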

(low, upp)
}
}

impl CompleteState {
fn advance(&mut self) {
*self = match *self {
CompleteState::Start { n, k } => {
let indices = (0..n).collect();
let cycles = ((n - k)..n).rev().collect();

CompleteState::Ongoing { cycles, indices }
}
CompleteState::Ongoing {
ref mut indices,
ref mut cycles,
} => {
let n = indices.len();
let k = cycles.len();

for i in (0..k).rev() {
if cycles[i] == 0 {
cycles[i] = n - i - 1;

let to_push = indices.remove(i);
indices.push(to_push);
} else {
let swap_index = n - cycles[i];
indices.swap(i, swap_index);

cycles[i] -= 1;
return;
}
}
impl<I> FusedIterator for Permutations<I>
where
I: Iterator,
I::Item: Clone,
{
}

CompleteState::Start { n, k }
}
fn advance(indices: &mut [usize], cycles: &mut [usize]) -> bool {
let n = indices.len();
let k = cycles.len();
// NOTE: if `cycles` are only zeros, then we reached the last permutation.
for i in (0..k).rev() {
if cycles[i] == 0 {
cycles[i] = n - i - 1;
indices[i..].rotate_left(1);
} else {
let swap_index = n - cycles[i];
indices.swap(i, swap_index);
cycles[i] -= 1;
return false;
}
}
true
}

/// Returns the count of remaining permutations, or None if it would overflow.
fn remaining(&self) -> Option<usize> {
impl PermutationState {
fn size_hint_for(&self, n: usize) -> SizeHint {
// At the beginning, there are `n!/(n-k)!` items to come.
let at_start = |n, k| {
debug_assert!(n >= k);
Member

Could you briefly explain why this debug_assert holds?

Member Author

@Philippe-Cholet Oct 31, 2023

I wrote that in my previous branch some weeks ago, so I was no longer very familiar with it, and a glance at it is definitely not enough, as it requires a bit of thinking.

This debug_assert! is only reached with the Start and Buffered variants.
When the iterator is created, we prefill the lazy buffer with k values. It has enough values (otherwise we would have the End variant), so vals.len() >= k (vals.len() == k at creation, more later).
size_hint_for is then called with:

  • in the case of count: n = vals.count() >= vals.len() (see lazy buffer for >=) ;
  • in the case of size_hint: n = vals.size_hint().0 >= vals.len() (see lazy buffer).
    Similar for n = vals.size_hint().1.

So in each case: n >= vals.len() >= k. Basically, it holds because we prefilled with k values.

Member Author

However, I'm considering soon working on making all our iterators lazy (such as #602), and I will surely turn that assertion into if n < k { return (0, Some(0)); } (and move the "prefill the lazy buffer" part).
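
A rough sketch of what that guarded version could look like, assuming the closure is lifted into a free function. The name `at_start_guarded` is hypothetical; the product and the overflow handling mirror the `at_start` closure in the diff, and the early return is the one quoted in this comment.

```rust
type SizeHint = (usize, Option<usize>);

/// Hypothetical variant of the diff's `at_start` closure that uses the guard
/// mentioned in this comment instead of `debug_assert!(n >= k)`.
fn at_start_guarded(n: usize, k: usize) -> SizeHint {
    if n < k {
        // Fewer than `k` items: no permutation of length `k` exists.
        return (0, Some(0));
    }
    // n!/(n-k)! = (n-k+1) * ... * n, with `None` on overflow.
    let total = (n - k + 1..=n).try_fold(1usize, |acc, i| acc.checked_mul(i));
    (total.unwrap_or(usize::MAX), total)
}
```

The guard also keeps `n - k` from underflowing, which would panic in a debug build when `n < k`.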

let total = (n - k + 1..=n).try_fold(1usize, |acc, i| acc.checked_mul(i));
(total.unwrap_or(usize::MAX), total)
};
match *self {
CompleteState::Start { n, k } => {
if n < k {
return Some(0);
}
(n - k + 1..=n).try_fold(1usize, |acc, i| acc.checked_mul(i))
Self::Start { k } => at_start(n, k),
Self::Buffered { k, min_n } => {
// Same as `Start` minus the previously generated items.
size_hint::sub_scalar(at_start(n, k), min_n - k + 1)
}
CompleteState::Ongoing {
Self::Loaded {
ref indices,
ref cycles,
} => cycles.iter().enumerate().try_fold(0usize, |acc, (i, &c)| {
acc.checked_mul(indices.len() - i)
.and_then(|count| count.checked_add(c))
}),
} => {
let count = cycles.iter().enumerate().try_fold(0usize, |acc, (i, &c)| {
acc.checked_mul(indices.len() - i)
.and_then(|count| count.checked_add(c))
});
(count.unwrap_or(usize::MAX), count)
}
Self::End => (0, Some(0)),
}
}
}
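
To see how the new free function `advance` walks through the permutations, it can be exercised on its own. The sketch below copies `advance` from the diff and adds a hypothetical driver, `index_permutations`, which is not part of the PR and assumes `k <= n`. It works on indices only, whereas the real iterator maps them onto buffered values.

```rust
/// Copied from the new code in the diff: steps to the next permutation.
/// Returns `true` once every permutation has been visited.
fn advance(indices: &mut [usize], cycles: &mut [usize]) -> bool {
    let n = indices.len();
    let k = cycles.len();
    for i in (0..k).rev() {
        if cycles[i] == 0 {
            // Position `i` is exhausted: reset it and restore the tail order.
            cycles[i] = n - i - 1;
            indices[i..].rotate_left(1);
        } else {
            let swap_index = n - cycles[i];
            indices.swap(i, swap_index);
            cycles[i] -= 1;
            return false;
        }
    }
    true
}

/// Hypothetical driver: collect every `k`-permutation of `0..n` as index vectors.
/// Assumes `k <= n`.
fn index_permutations(n: usize, k: usize) -> Vec<Vec<usize>> {
    // Same initial state as the `Buffered` -> `Loaded` transition in the diff.
    let mut indices: Vec<usize> = (0..n).collect();
    let mut cycles: Vec<usize> = (n - k..n).rev().collect();
    let mut out = vec![indices[0..k].to_vec()];
    while !advance(&mut indices, &mut cycles) {
        out.push(indices[0..k].to_vec());
    }
    out
}
```

For `n = 3` and `k = 2` the driver yields the six index pairs in lexicographic order, matching the `n!/(n-k)!` count used by `size_hint_for`.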