From 9bd778d4ddcc000c053256c45027793a58c17886 Mon Sep 17 00:00:00 2001 From: varon Date: Sat, 4 Feb 2023 23:16:47 +0200 Subject: [PATCH] Add tests for NaN handling --- dev_utils/Cargo.toml | 1 + dev_utils/src/utils.rs | 23 ++++++++++++++++++---- tests/argminmax_test.rs | 43 +++++++++++++++++++++++++++++++++++++---- 3 files changed, 59 insertions(+), 8 deletions(-) diff --git a/dev_utils/Cargo.toml b/dev_utils/Cargo.toml index b2fa747..7f4f17e 100644 --- a/dev_utils/Cargo.toml +++ b/dev_utils/Cargo.toml @@ -6,5 +6,6 @@ edition = "2021" description = "Shared utilities for development (tests & benchmarks)" [dependencies] +num-traits = { version = "0.2.15", default-features = false } rand = { version = "0.7.2", default-features = false } rand_distr = { version = "0.2.2", default-features = false } diff --git a/dev_utils/src/utils.rs b/dev_utils/src/utils.rs index 2d8383f..26efa10 100644 --- a/dev_utils/src/utils.rs +++ b/dev_utils/src/utils.rs @@ -2,21 +2,36 @@ use std::ops::{Add, Sub}; use rand::{thread_rng, Rng}; use rand_distr::Uniform; +use num_traits::float::FloatCore; // random array that samples between min and max of T pub fn get_random_array(n: usize, min_value: T, max_value: T) -> Vec -where - T: Copy + rand::distributions::uniform::SampleUniform, + where + T: Copy + rand::distributions::uniform::SampleUniform, { let rng = thread_rng(); let uni = Uniform::new_inclusive(min_value, max_value); rng.sample_iter(uni).take(n).collect() } +/// Inserts NaN values into the provided array at the specified frequency. +/// * `name` - normalized_frequency - probability [0-1] of a NaN being inserted for each value. +pub fn insert_nans(values: &mut Vec, normalized_frequency: f32) + where + T: FloatCore, +{ + let mut rng = thread_rng(); + for i in 0..values.len() { + if normalized_frequency > rng.gen::() { + values[i] = FloatCore::nan(); + } + } +} + // worst case array that alternates between increasing max and decreasing min values pub fn get_worst_case_array(n: usize, step: T) -> Vec -where - T: Copy + Default + Sub + Add, + where + T: Copy + Default + Sub + Add, { let mut arr: Vec = Vec::with_capacity(n); let mut min_value: T = Default::default(); diff --git a/tests/argminmax_test.rs b/tests/argminmax_test.rs index ebd3786..e1a3cd2 100644 --- a/tests/argminmax_test.rs +++ b/tests/argminmax_test.rs @@ -44,6 +44,40 @@ fn test_argminmax_vec() { assert_eq!(max, ARRAY_LENGTH - 1); } + +#[test] +fn test_argminmax_vec_with_nans() { + let mut data: Vec = (0..ARRAY_LENGTH).map(|x| x as f32).collect(); + utils::insert_nans(&mut data, 0.05); + // Test owned vec + let (min, max) = data.argminmax(); + assert_eq!(min, 0); + assert_eq!(max, ARRAY_LENGTH - 1); + // Test borrowed vec + let (min, max) = (&data).argminmax(); + assert_eq!(min, 0); + assert_eq!(max, ARRAY_LENGTH - 1); +} + + +#[test] +fn test_argminmax_vec_with_start_end_nans() { + + let mut data: Vec = (0..100).map(|x| x as f32).collect(); + data[0] = f32::NAN; + data[100-1] = f32::NAN; + // Test owned vec + let (min, max) = data.argminmax(); + assert_eq!(min, 1); + assert_eq!(max, 100 - 2); + // Test borrowed vec + let (min, max) = (&data).argminmax(); + assert_eq!(min, 1); + assert_eq!(max, 100 - 2); +} + + + #[cfg(feature = "ndarray")] #[test] fn test_argminmax_ndarray() { @@ -94,14 +128,14 @@ fn test_argminmax_many_random_runs() { let (min_vec, max_vec) = data.argminmax(); // Array1 #[cfg(feature = "ndarray")] - let array: Array1 = Array1::from_vec(slice.to_vec()); + let array: Array1 = Array1::from_vec(slice.to_vec()); #[cfg(feature = "ndarray")] - let (min_array, max_array) = array.argminmax(); + let (min_array, max_array) = array.argminmax(); // Arrow #[cfg(feature = "arrow")] - let arrow: Float32Array = Float32Array::from(slice.to_vec()); + let arrow: Float32Array = Float32Array::from(slice.to_vec()); #[cfg(feature = "arrow")] - let (min_arrow, max_arrow) = arrow.argminmax(); + let (min_arrow, max_arrow) = arrow.argminmax(); // Assert assert_eq!(min_slice, min_vec); assert_eq!(max_slice, max_vec); @@ -115,3 +149,4 @@ fn test_argminmax_many_random_runs() { assert_eq!(max_slice, max_arrow); } } +