Skip to content

Commit

Permalink
allow user to set ParticleSwarm's random number generator (argmin-rs#383
Browse files Browse the repository at this point in the history
)

* allow user to set ParticleSwarm random number generator

* use R for Random and make default nondeterministic

use T where R is not possible

* update comment on default rng generator

Co-authored-by: Stefan Kroboth <stefan.kroboth@gmail.com>

---------

Co-authored-by: Stefan Kroboth <stefan.kroboth@gmail.com>
  • Loading branch information
jonboh and stefan-k authored Jan 7, 2024
1 parent ae0279b commit 220e0d5
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 47 deletions.
3 changes: 2 additions & 1 deletion argmin-math/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ mod vec;
pub use crate::vec::*;

use anyhow::Error;
use rand::Rng;

/// Dot/scalar product of `T` and `self`
pub trait ArgminDot<T, U> {
Expand Down Expand Up @@ -340,7 +341,7 @@ pub trait ArgminInv<T> {
/// Create a random number
pub trait ArgminRandom {
/// Get a random element between min and max,
fn rand_from_range(min: &Self, max: &Self) -> Self;
fn rand_from_range<R: Rng>(min: &Self, max: &Self, rng: &mut R) -> Self;
}

/// Minimum and Maximum of type `T`
Expand Down
17 changes: 10 additions & 7 deletions argmin-math/src/nalgebra_m/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ where
DefaultAllocator: Allocator<N, R, C>,
{
#[inline]
fn rand_from_range(min: &Self, max: &Self) -> OMatrix<N, R, C> {
fn rand_from_range<T: Rng>(min: &Self, max: &Self, rng: &mut T) -> OMatrix<N, R, C> {
assert!(!min.is_empty());
assert_eq!(min.shape(), max.shape());

let mut rng = rand::thread_rng();

Self::from_iterator_generic(
R::from_usize(min.nrows()),
C::from_usize(min.ncols()),
Expand All @@ -53,6 +51,7 @@ mod tests {
use super::*;
use nalgebra::{Matrix2x3, Vector3};
use paste::item;
use rand::SeedableRng;

macro_rules! make_test {
($t:ty) => {
Expand All @@ -61,7 +60,8 @@ mod tests {
fn [<test_random_vec_ $t>]() {
let a = Vector3::new(1 as $t, 2 as $t, 3 as $t);
let b = Vector3::new(2 as $t, 3 as $t, 4 as $t);
let random = Vector3::<$t>::rand_from_range(&a, &b);
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let random = Vector3::<$t>::rand_from_range(&a, &b, &mut rng);
for i in 0..3 {
assert!(random[i] >= a[i]);
assert!(random[i] <= b[i]);
Expand All @@ -74,7 +74,8 @@ mod tests {
fn [<test_random_vec_equal $t>]() {
let a = Vector3::new(1 as $t, 2 as $t, 3 as $t);
let b = Vector3::new(1 as $t, 2 as $t, 3 as $t);
let random = Vector3::<$t>::rand_from_range(&a, &b);
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let random = Vector3::<$t>::rand_from_range(&a, &b, &mut rng);
for i in 0..3 {
assert!((random[i] as f64 - a[i] as f64).abs() < std::f64::EPSILON);
assert!((random[i] as f64 - b[i] as f64).abs() < std::f64::EPSILON);
Expand All @@ -87,7 +88,8 @@ mod tests {
fn [<test_random_vec_reverse_ $t>]() {
let b = Vector3::new(1 as $t, 2 as $t, 3 as $t);
let a = Vector3::new(2 as $t, 3 as $t, 4 as $t);
let random = Vector3::<$t>::rand_from_range(&a, &b);
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let random = Vector3::<$t>::rand_from_range(&a, &b, &mut rng);
for i in 0..3 {
assert!(random[i] >= b[i]);
assert!(random[i] <= a[i]);
Expand All @@ -106,7 +108,8 @@ mod tests {
2 as $t, 4 as $t, 6 as $t,
3 as $t, 5 as $t, 7 as $t
);
let random = Matrix2x3::<$t>::rand_from_range(&a, &b);
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let random = Matrix2x3::<$t>::rand_from_range(&a, &b, &mut rng);
for i in 0..3 {
for j in 0..2 {
assert!(random[(j, i)] >= a[(j, i)]);
Expand Down
17 changes: 8 additions & 9 deletions argmin-math/src/ndarray_m/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,17 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use rand::Rng;
use rand::{Rng, SeedableRng};

use crate::ArgminRandom;

macro_rules! make_random {
($t:ty) => {
impl ArgminRandom for ndarray::Array1<$t> {
fn rand_from_range(min: &Self, max: &Self) -> ndarray::Array1<$t> {
fn rand_from_range<R: Rng>(min: &Self, max: &Self, rng: &mut R) -> ndarray::Array1<$t> {
assert!(!min.is_empty());
assert_eq!(min.len(), max.len());

let mut rng = rand::thread_rng();

ndarray::Array1::from_iter(min.iter().zip(max.iter()).map(|(a, b)| {
// Do not require a < b:

Expand All @@ -35,12 +33,10 @@ macro_rules! make_random {
}

impl ArgminRandom for ndarray::Array2<$t> {
fn rand_from_range(min: &Self, max: &Self) -> ndarray::Array2<$t> {
fn rand_from_range<R: Rng>(min: &Self, max: &Self, rng: &mut R) -> ndarray::Array2<$t> {
assert!(!min.is_empty());
assert_eq!(min.raw_dim(), max.raw_dim());

let mut rng = rand::thread_rng();

ndarray::Array2::from_shape_fn(min.raw_dim(), |(i, j)| {
let a = min.get((i, j)).unwrap();
let b = max.get((i, j)).unwrap();
Expand Down Expand Up @@ -78,6 +74,7 @@ mod tests {
use super::*;
use ndarray::{array, Array1, Array2};
use paste::item;
use rand::SeedableRng;

macro_rules! make_test {
($t:ty) => {
Expand All @@ -86,7 +83,8 @@ mod tests {
fn [<test_random_vec_ $t>]() {
let a = array![1 as $t, 2 as $t, 4 as $t];
let b = array![2 as $t, 3 as $t, 5 as $t];
let random = Array1::<$t>::rand_from_range(&a, &b);
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let random = Array1::<$t>::rand_from_range(&a, &b, &mut rng);
for i in 0..3usize {
assert!(random[i] >= a[i]);
assert!(random[i] <= b[i]);
Expand All @@ -105,7 +103,8 @@ mod tests {
[2 as $t, 3 as $t, 5 as $t],
[3 as $t, 4 as $t, 6 as $t]
];
let random = Array2::<$t>::rand_from_range(&a, &b);
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let random = Array2::<$t>::rand_from_range(&a, &b, &mut rng);
for i in 0..3 {
for j in 0..2 {
assert!(random[(j, i)] >= a[(j, i)]);
Expand Down
8 changes: 5 additions & 3 deletions argmin-math/src/primitives/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ macro_rules! make_random {
($t:ty) => {
impl ArgminRandom for $t {
#[inline]
fn rand_from_range(min: &Self, max: &Self) -> $t {
rand::thread_rng().gen_range(*min..*max)
fn rand_from_range<R: Rng>(min: &Self, max: &Self, rng: &mut R) -> $t {
rng.gen_range(*min..*max)
}
}
};
Expand All @@ -36,6 +36,7 @@ make_random!(usize);
mod tests {
use super::*;
use paste::item;
use rand::SeedableRng;

macro_rules! make_test {
($t:ty) => {
Expand All @@ -44,7 +45,8 @@ mod tests {
fn [<test_random_vec_ $t>]() {
let a = 1 as $t;
let b = 2 as $t;
let random = $t::rand_from_range(&a, &b);
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let random = $t::rand_from_range(&a, &b, &mut rng);
assert!(random >= a);
assert!(random <= b);
}
Expand Down
15 changes: 8 additions & 7 deletions argmin-math/src/vec/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@ use rand::Rng;
macro_rules! make_random {
($t:ty) => {
impl ArgminRandom for Vec<$t> {
fn rand_from_range(min: &Self, max: &Self) -> Vec<$t> {
fn rand_from_range<R: Rng>(min: &Self, max: &Self, rng: &mut R) -> Vec<$t> {
assert!(!min.is_empty());
assert_eq!(min.len(), max.len());

let mut rng = rand::thread_rng();

min.iter()
.zip(max.iter())
.map(|(a, b)| {
Expand All @@ -37,12 +35,12 @@ macro_rules! make_random {
}

impl ArgminRandom for Vec<Vec<$t>> {
fn rand_from_range(min: &Self, max: &Self) -> Vec<Vec<$t>> {
fn rand_from_range<R: Rng>(min: &Self, max: &Self, rng: &mut R) -> Vec<Vec<$t>> {
assert!(!min.is_empty());
assert_eq!(min.len(), max.len());
min.iter()
.zip(max.iter())
.map(|(a, b)| Vec::<$t>::rand_from_range(a, b))
.map(|(a, b)| Vec::<$t>::rand_from_range(a, b, rng))
.collect()
}
}
Expand All @@ -66,6 +64,7 @@ make_random!(usize);
mod tests {
use super::*;
use paste::item;
use rand::SeedableRng;

macro_rules! make_test {
($t:ty) => {
Expand All @@ -74,7 +73,8 @@ mod tests {
fn [<test_random_vec_ $t>]() {
let a = vec![1 as $t, 2 as $t, 4 as $t];
let b = vec![2 as $t, 3 as $t, 5 as $t];
let random = Vec::<$t>::rand_from_range(&a, &b);
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let random = Vec::<$t>::rand_from_range(&a, &b, &mut rng);
for i in 0..3usize {
assert!(random[i] >= a[i]);
assert!(random[i] <= b[i]);
Expand All @@ -93,7 +93,8 @@ mod tests {
vec![2 as $t, 3 as $t, 5 as $t],
vec![3 as $t, 4 as $t, 6 as $t]
];
let random = Vec::<Vec<$t>>::rand_from_range(&a, &b);
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let random = Vec::<Vec<$t>>::rand_from_range(&a, &b, &mut rng);
for i in 0..3 {
for j in 0..2 {
assert!(random[j][i] >= a[j][i]);
Expand Down
Loading

0 comments on commit 220e0d5

Please sign in to comment.