Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for NaN handling #19

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dev_utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ edition = "2021"
description = "Shared utilities for development (tests & benchmarks)"

[dependencies]
num-traits = { version = "0.2.15", default-features = false }
varon marked this conversation as resolved.
Show resolved Hide resolved
rand = { version = "0.7.2", default-features = false }
rand_distr = { version = "0.2.2", default-features = false }
23 changes: 19 additions & 4 deletions dev_utils/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,36 @@ use std::ops::{Add, Sub};

use rand::{thread_rng, Rng};
use rand_distr::Uniform;
use num_traits::float::FloatCore;

// random array that samples between min and max of T
pub fn get_random_array<T>(n: usize, min_value: T, max_value: T) -> Vec<T>
where
T: Copy + rand::distributions::uniform::SampleUniform,
where
T: Copy + rand::distributions::uniform::SampleUniform,
jvdd marked this conversation as resolved.
Show resolved Hide resolved
{
let rng = thread_rng();
let uni = Uniform::new_inclusive(min_value, max_value);
rng.sample_iter(uni).take(n).collect()
}

/// Inserts NaN values into the provided array at the specified frequency.
/// * `name` - normalized_frequency - probability [0-1] of a NaN being inserted for each value.
pub fn insert_nans<T>(values: &mut Vec<T>, normalized_frequency: f32)
where
T: FloatCore,
varon marked this conversation as resolved.
Show resolved Hide resolved
{
let mut rng = thread_rng();
for i in 0..values.len() {
if normalized_frequency > rng.gen::<f32>() {
values[i] = FloatCore::nan();
}
}
}

// worst case array that alternates between increasing max and decreasing min values
pub fn get_worst_case_array<T>(n: usize, step: T) -> Vec<T>
where
T: Copy + Default + Sub<Output = T> + Add<Output = T>,
where
T: Copy + Default + Sub<Output=T> + Add<Output=T>,
{
let mut arr: Vec<T> = Vec::with_capacity(n);
let mut min_value: T = Default::default();
Expand Down
43 changes: 39 additions & 4 deletions tests/argminmax_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,40 @@ fn test_argminmax_vec() {
assert_eq!(max, ARRAY_LENGTH - 1);
}


#[test]
fn test_argminmax_vec_with_nans() {
let mut data: Vec<f32> = (0..ARRAY_LENGTH).map(|x| x as f32).collect();
utils::insert_nans(&mut data, 0.05);
// Test owned vec
let (min, max) = data.argminmax();
assert_eq!(min, 0);
assert_eq!(max, ARRAY_LENGTH - 1);
// Test borrowed vec
let (min, max) = (&data).argminmax();
assert_eq!(min, 0);
assert_eq!(max, ARRAY_LENGTH - 1);
Comment on lines +51 to +59
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test passes almost all the time, at least on the architecture I'm using.
This likely has to do with assumptions about the sorting of the data.

}


#[test]
fn test_argminmax_vec_with_start_end_nans() {

let mut data: Vec<f32> = (0..100).map(|x| x as f32).collect();
data[0] = f32::NAN;
data[100-1] = f32::NAN;
varon marked this conversation as resolved.
Show resolved Hide resolved
// Test owned vec
let (min, max) = data.argminmax();
assert_eq!(min, 1);
assert_eq!(max, 100 - 2);
// Test borrowed vec
let (min, max) = (&data).argminmax();
assert_eq!(min, 1);
assert_eq!(max, 100 - 2);
}



#[cfg(feature = "ndarray")]
#[test]
fn test_argminmax_ndarray() {
Expand Down Expand Up @@ -94,14 +128,14 @@ fn test_argminmax_many_random_runs() {
let (min_vec, max_vec) = data.argminmax();
// Array1
#[cfg(feature = "ndarray")]
let array: Array1<f32> = Array1::from_vec(slice.to_vec());
let array: Array1<f32> = Array1::from_vec(slice.to_vec());
#[cfg(feature = "ndarray")]
let (min_array, max_array) = array.argminmax();
let (min_array, max_array) = array.argminmax();
// Arrow
#[cfg(feature = "arrow")]
let arrow: Float32Array = Float32Array::from(slice.to_vec());
let arrow: Float32Array = Float32Array::from(slice.to_vec());
#[cfg(feature = "arrow")]
let (min_arrow, max_arrow) = arrow.argminmax();
let (min_arrow, max_arrow) = arrow.argminmax();
jvdd marked this conversation as resolved.
Show resolved Hide resolved
// Assert
assert_eq!(min_slice, min_vec);
assert_eq!(max_slice, max_vec);
Expand All @@ -115,3 +149,4 @@ fn test_argminmax_many_random_runs() {
assert_eq!(max_slice, max_arrow);
}
}