Skip to content

Commit

Permalink
kk: replace panics with errors
Browse files Browse the repository at this point in the history
  • Loading branch information
hhirtz committed Mar 23, 2022
1 parent 0b8bf5c commit e6b14ec
Showing 1 changed file with 25 additions and 27 deletions.
52 changes: 25 additions & 27 deletions src/algorithms/kk.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::Error;
use std::collections::BinaryHeap;
use std::ops::Sub;
use std::ops::SubAssign;
Expand All @@ -9,18 +10,14 @@ use num::Zero;
/// # Differences with the k-partitioning implementation
///
/// This function has better performance than [kk] called with `num_parts == 2`.
fn kk_bipart<T>(partition: &mut [usize], weights: impl IntoIterator<Item = T>)
fn kk_bipart<T>(partition: &mut [usize], weights: impl Iterator<Item = T>)
where
T: Ord + Sub<Output = T>,
{
let mut weights: BinaryHeap<(T, usize)> = weights
.into_iter()
.zip(0..) // Keep track of the weights' indicies
.collect();
assert_eq!(partition.len(), weights.len());
if weights.is_empty() {
return;
}

// Core algorithm: find the imbalance of the partition.
// "opposites" is built in this loop to backtrack the solution. It tracks weights that must end
Expand Down Expand Up @@ -54,30 +51,15 @@ where
fn kk<T, I>(partition: &mut [usize], weights: I, num_parts: usize)
where
T: Zero + Ord + Sub<Output = T> + SubAssign + Copy,
I: IntoIterator<Item = T>,
<I as IntoIterator>::IntoIter: ExactSizeIterator,
I: Iterator<Item = T> + ExactSizeIterator,
{
let weights = weights.into_iter();
let num_weights = weights.len();
assert_eq!(partition.len(), num_weights);
if num_weights == 0 {
return;
}
if num_parts < 2 {
return;
}
if num_parts == 2 {
// The bi-partitioning is a special case that can be handled faster than
// the general case.
return kk_bipart(partition, weights);
}

// Initialize "m", a "k*num_weights" matrix whose first column is "weights".
let weight_count = weights.len();
let mut m: BinaryHeap<Vec<(T, usize)>> = weights
.zip(0..)
.map(|(w, id)| {
let mut v: Vec<(T, usize)> = (0..num_parts)
.map(|p| (T::zero(), num_weights * p + id))
.map(|p| (T::zero(), weight_count * p + id))
.collect();
v[0].0 = w;
v
Expand All @@ -88,7 +70,7 @@ where
// largest weights in two different parts, the largest weight of each row is put into the same
// part as the smallest one, and so on.

let mut opposites = Vec::with_capacity(num_weights);
let mut opposites = Vec::with_capacity(weight_count);
while 2 <= m.len() {
let a = m.pop().unwrap();
let b = m.pop().unwrap();
Expand Down Expand Up @@ -119,7 +101,7 @@ where
// Backtracking. Same as the bi-partitioning case.

// parts = [ [m0i] for m0i in m[0] ]
let mut parts: Vec<usize> = vec![0; num_parts * num_weights];
let mut parts: Vec<usize> = vec![0; num_parts * weight_count];
let imbalance = m.pop().unwrap(); // first and last element of "m".
for (i, w) in imbalance.into_iter().enumerate() {
// Put each remaining element in a different part.
Expand Down Expand Up @@ -160,14 +142,30 @@ where
W::Item: Zero + Ord + Sub<Output = W::Item> + SubAssign + Copy,
{
type Metadata = ();
type Error = std::convert::Infallible;
type Error = Error;

fn partition(
&mut self,
part_ids: &mut [usize],
weights: W,
) -> Result<Self::Metadata, Self::Error> {
kk(part_ids, weights, self.part_count);
if self.part_count < 2 || part_ids.len() < 2 {
return Ok(());
}
let weights = weights.into_iter();
if weights.len() != part_ids.len() {
return Err(Error::InputLenMismatch {
expected: part_ids.len(),
actual: weights.len(),
});
}
if self.part_count == 2 {
// The bi-partitioning is a special case that can be handled faster
// than the general case.
kk_bipart(part_ids, weights);
} else {
kk(part_ids, weights, self.part_count);
}
Ok(())
}
}

0 comments on commit e6b14ec

Please sign in to comment.