diff --git a/src/algorithms/kk.rs b/src/algorithms/kk.rs index 944fe701..60b970fc 100644 --- a/src/algorithms/kk.rs +++ b/src/algorithms/kk.rs @@ -1,3 +1,4 @@ +use super::Error; use std::collections::BinaryHeap; use std::ops::Sub; use std::ops::SubAssign; @@ -9,7 +10,7 @@ use num::Zero; /// # Differences with the k-partitioning implementation /// /// This function has better performance than [kk] called with `num_parts == 2`. -fn kk_bipart(partition: &mut [usize], weights: impl IntoIterator) +fn kk_bipart(partition: &mut [usize], weights: impl Iterator) where T: Ord + Sub, { @@ -17,10 +18,6 @@ where .into_iter() .zip(0..) // Keep track of the weights' indicies .collect(); - assert_eq!(partition.len(), weights.len()); - if weights.is_empty() { - return; - } // Core algorithm: find the imbalance of the partition. // "opposites" is built in this loop to backtrack the solution. It tracks weights that must end @@ -54,30 +51,15 @@ where fn kk(partition: &mut [usize], weights: I, num_parts: usize) where T: Zero + Ord + Sub + SubAssign + Copy, - I: IntoIterator, - ::IntoIter: ExactSizeIterator, + I: Iterator + ExactSizeIterator, { - let weights = weights.into_iter(); - let num_weights = weights.len(); - assert_eq!(partition.len(), num_weights); - if num_weights == 0 { - return; - } - if num_parts < 2 { - return; - } - if num_parts == 2 { - // The bi-partitioning is a special case that can be handled faster than - // the general case. - return kk_bipart(partition, weights); - } - // Initialize "m", a "k*num_weights" matrix whose first column is "weights". + let weight_count = weights.len(); let mut m: BinaryHeap> = weights .zip(0..) .map(|(w, id)| { let mut v: Vec<(T, usize)> = (0..num_parts) - .map(|p| (T::zero(), num_weights * p + id)) + .map(|p| (T::zero(), weight_count * p + id)) .collect(); v[0].0 = w; v @@ -88,7 +70,7 @@ where // largest weights in two different parts, the largest weight of each row is put into the same // part as the smallest one, and so on. - let mut opposites = Vec::with_capacity(num_weights); + let mut opposites = Vec::with_capacity(weight_count); while 2 <= m.len() { let a = m.pop().unwrap(); let b = m.pop().unwrap(); @@ -119,7 +101,7 @@ where // Backtracking. Same as the bi-partitioning case. // parts = [ [m0i] for m0i in m[0] ] - let mut parts: Vec = vec![0; num_parts * num_weights]; + let mut parts: Vec = vec![0; num_parts * weight_count]; let imbalance = m.pop().unwrap(); // first and last element of "m". for (i, w) in imbalance.into_iter().enumerate() { // Put each remaining element in a different part. @@ -160,14 +142,30 @@ where W::Item: Zero + Ord + Sub + SubAssign + Copy, { type Metadata = (); - type Error = std::convert::Infallible; + type Error = Error; fn partition( &mut self, part_ids: &mut [usize], weights: W, ) -> Result { - kk(part_ids, weights, self.part_count); + if self.part_count < 2 || part_ids.len() < 2 { + return Ok(()); + } + let weights = weights.into_iter(); + if weights.len() != part_ids.len() { + return Err(Error::InputLenMismatch { + expected: part_ids.len(), + actual: weights.len(), + }); + } + if self.part_count == 2 { + // The bi-partitioning is a special case that can be handled faster + // than the general case. + kk_bipart(part_ids, weights); + } else { + kk(part_ids, weights, self.part_count); + } Ok(()) } }