
WeightedIndex: Make it possible to update a subset of weights #866

Merged: 9 commits, Aug 22, 2019
38 changes: 38 additions & 0 deletions benches/weighted.rs
@@ -0,0 +1,38 @@
// Copyright 2019 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

#![feature(test)]

extern crate test;

const RAND_BENCH_N: u64 = 1000;

use test::Bencher;
use rand::Rng;
use rand::distributions::WeightedIndex;

#[bench]
fn weighted_index_creation(b: &mut Bencher) {
    let mut rng = rand::thread_rng();
    b.iter(|| {
        let weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7];
        let distr = WeightedIndex::new(weights.to_vec()).unwrap();
        rng.sample(distr)
    })
}

#[bench]
fn weighted_index_modification(b: &mut Bencher) {
    let mut rng = rand::thread_rng();
    let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
    let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
    b.iter(|| {
        distr.update_weights(&[(2, &4), (5, &1)]).unwrap();
        rng.sample(&distr)
    })
}
75 changes: 73 additions & 2 deletions src/distributions/weighted/mod.rs
@@ -84,6 +84,7 @@ use core::fmt;
#[derive(Debug, Clone)]
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
    cumulative_weights: Vec<X>,
    total_weight: X,
    weight_distribution: X::Sampler,
}

@@ -125,9 +126,63 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
        if total_weight == zero {
            return Err(WeightedError::AllWeightsZero);
        }
-       let distr = X::Sampler::new(zero, total_weight);
+       let distr = X::Sampler::new(zero, total_weight.clone());

-       Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
+       Ok(WeightedIndex { cumulative_weights: weights, total_weight, weight_distribution: distr })
    }

    /// Update a subset of weights, without changing the number of weights.
    ///
    /// Using this method instead of `new` might be more efficient if only a small number of
    /// weights is modified. For weights that are `Copy`, no allocations are performed.
    ///
    /// In case of error, `self` is not modified.
    pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError>
        where X: for<'a> ::core::ops::AddAssign<&'a X> +
                 for<'a> ::core::ops::SubAssign<&'a X> +
                 Clone +
                 Default {
        let zero = <X as Default>::default();

        let mut total_weight = self.total_weight.clone();

        for &(i, w) in new_weights {
            if *w < zero {
                return Err(WeightedError::InvalidWeight);
            }
            if i >= self.cumulative_weights.len() {
                return Err(WeightedError::TooMany);
Review comment (Member): Would be worth adding `InvalidIndex`, except that it's a breaking change. Perhaps do so in a separate PR which we don't land until we start preparing the next Rand version?

Reply (Collaborator, author): Yeah, I thought about this as well. Will do once this is merged.

            }

            // Unfortunately, we will have to calculate the non-cumulative weight a second time, to
            // avoid producing an invalid state of `self`.
            let mut old_w = self.cumulative_weights[i].clone();
            if i > 0 {
                old_w -= &self.cumulative_weights[i - 1];
            }

            total_weight -= &old_w;
            total_weight += w;
        }
        if total_weight == zero {
            return Err(WeightedError::AllWeightsZero);
        }

        for &(i, w) in new_weights {
            let mut old_w = self.cumulative_weights[i].clone();
            if i > 0 {
                old_w -= &self.cumulative_weights[i - 1];
            }

            for j in i..self.cumulative_weights.len() {
Review comment (Member): This is O(n*m), where n = cumulative_weights.len() - min_index and m = new_weights.len(). Instead we should sort the new_weights by index, then apply them in turn (like in `new`); this is O(m*log(m) + n). Also, we can just take total_weight = cumulative_weights.last().unwrap().

Reply (Collaborator, author): > Instead we should sort the new_weights by index, then apply them in turn (like in `new`); this is O(m*log(m) + n).

I'll look into this.

> Also, we can just take total_weight = cumulative_weights.last().unwrap().

I don't think so; the last cumulative weight is not stored in the vector. Or are you saying we should change it such that it is?

Reply (Member): Aha, `binary_search_by` is happy to return an index one past the last item, therefore the final weight is not needed. (And we have a motive for not including the final weight: it guarantees we will never exceed the last index of the input weights list.) Then yes, we need to store either the last weight or the total as an extra field.

Reply (Collaborator, author): > Instead we should sort the new_weights by index, then apply them in turn (like in `new`); this is O(m*log(m) + n).

I implemented that. It's a bit messy, because the index type might be unsigned.

                self.cumulative_weights[j] -= &old_w;
                self.cumulative_weights[j] += w;
            }
        }
        self.total_weight = total_weight;
        self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone());

        Ok(())
    }
}
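The bookkeeping in update_weights (recover each old non-cumulative weight, then shift every suffix sum by the difference) can be sketched on plain u64 weights. This is a hypothetical, std-only model: the helper names are made up, and it uses a full cumulative table that includes the grand total, unlike the crate's truncated layout with a separate total_weight field.

```rust
// Toy model of the update arithmetic (hypothetical helpers, u64 instead of
// the generic X). Here cumulative[i] is the sum of weights[0..=i].
fn cumulative_sums(weights: &[u64]) -> Vec<u64> {
    weights
        .iter()
        .scan(0u64, |acc, &w| {
            *acc += w;
            Some(*acc)
        })
        .collect()
}

fn update_weights(cumulative: &mut [u64], updates: &[(usize, u64)]) {
    for &(i, new_w) in updates {
        // Recover the old non-cumulative weight at index i, as the PR does.
        let old_w = cumulative[i] - if i > 0 { cumulative[i - 1] } else { 0 };
        // Every cumulative sum from i onward contains old_w; swap it for new_w.
        for c in cumulative[i..].iter_mut() {
            *c = *c - old_w + new_w;
        }
    }
}

fn main() {
    let mut cumulative = cumulative_sums(&[1, 2, 3, 0, 5]); // [1, 3, 6, 6, 11]
    update_weights(&mut cumulative, &[(2, 4), (4, 1)]);
    assert_eq!(cumulative, cumulative_sums(&[1, 2, 4, 0, 1]));
    println!("{:?}", cumulative); // prints [1, 3, 7, 7, 8]
}
```

The real method additionally validates every weight and index before mutating anything, which is why `self` can be left untouched on error; the sketch omits that checking.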

@@ -201,6 +256,22 @@ mod test {
        assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::InvalidWeight);
        assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::InvalidWeight);
    }

    #[test]
    fn test_update_weights() {
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>();
        let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
        assert_eq!(distr.total_weight, total_weight);

        distr.update_weights(&[(2, &4), (5, &1)]).unwrap();
        let expected_weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7];
        let expected_total_weight = expected_weights.iter().sum::<u32>();
        let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap();
        assert_eq!(distr.total_weight, expected_total_weight);
        assert_eq!(distr.total_weight, expected_distr.total_weight);
        assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights);
    }
}
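The review point that the binary search can land one past the last stored sum, which is why the final cumulative weight need not be stored, can be shown concretely. A std-only sketch, assuming u64 weights and using `partition_point` as a stand-in for the crate's `binary_search_by` call (the function name `pick` is illustrative):

```rust
// The table stores partial sums *excluding* the grand total: for weights
// [1, 2, 4] it holds [1, 3], and the total 7 lives in a separate field.
// For a uniform draw x in 0..total, partition_point counts how many stored
// sums are <= x, which is exactly the sampled index; it can equal
// table.len() (one past the last stored sum) when the final weight wins.
fn pick(table: &[u64], x: u64) -> usize {
    table.partition_point(|&c| c <= x)
}

fn main() {
    let table = [1u64, 3]; // weights [1, 2, 4], total = 7 kept elsewhere
    assert_eq!(pick(&table, 0), 0); // x in 0..1 -> index 0
    assert_eq!(pick(&table, 2), 1); // x in 1..3 -> index 1
    assert_eq!(pick(&table, 6), 2); // x in 3..7 -> index 2 == table.len()
}
```

Because the search never inspects the total itself, dropping the final cumulative sum both saves a slot and guarantees the returned index stays within the original weights list.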

/// Error type returned from `WeightedIndex::new`.