diff --git a/Cargo.toml b/Cargo.toml
index 5581ec4..3e1f4e4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -26,6 +26,9 @@
 codspeed-criterion-compat = "1.0.1"
 criterion = "0.3.1"
 dev_utils = { path = "dev_utils" }
 
+[profile.release]
+lto = true
+
 [[bench]]
 name = "bench_f16"
diff --git a/benches/bench_f32.rs b/benches/bench_f32.rs
index f71b815..9b06a64 100644
--- a/benches/bench_f32.rs
+++ b/benches/bench_f32.rs
@@ -26,7 +26,7 @@ fn minmax_f32_random_array_long(c: &mut Criterion) {
         });
     }
     #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
-    if is_x86_feature_detected!("avx") {
+    if is_x86_feature_detected!("avx2") {
         c.bench_function("avx_random_long_f32", |b| {
             b.iter(|| unsafe { AVX2::argminmax(black_box(data)) })
         });
@@ -67,7 +67,7 @@ fn minmax_f32_random_array_short(c: &mut Criterion) {
         });
     }
     #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
-    if is_x86_feature_detected!("avx") {
+    if is_x86_feature_detected!("avx2") {
         c.bench_function("avx_random_short_f32", |b| {
             b.iter(|| unsafe { AVX2::argminmax(black_box(data)) })
         });
@@ -108,7 +108,7 @@ fn minmax_f32_worst_case_array_long(c: &mut Criterion) {
         });
     }
     #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
-    if is_x86_feature_detected!("avx") {
+    if is_x86_feature_detected!("avx2") {
         c.bench_function("avx_worst_long_f32", |b| {
             b.iter(|| unsafe { AVX2::argminmax(black_box(data)) })
         });
@@ -149,7 +149,7 @@ fn minmax_f32_worst_case_array_short(c: &mut Criterion) {
         });
     }
     #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
-    if is_x86_feature_detected!("avx") {
+    if is_x86_feature_detected!("avx2") {
        c.bench_function("avx_worst_short_f32", |b| {
             b.iter(|| unsafe { AVX2::argminmax(black_box(data)) })
        });
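The benchmark gates change from `avx` to `avx2` because the rewritten f32 kernel compares floats as sign-flipped integers using AVX2 integer intrinsics (`_mm256_cmpgt_epi32`, `_mm256_blendv_epi8`), whereas the old float path only needed `_mm256_cmp_ps` from AVX. A minimal sketch of the runtime check the benches rely on (standalone illustration, not crate code):

```rust
// Report which argminmax path a benchmark run would exercise on this machine.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn main() {
    if is_x86_feature_detected!("avx2") {
        println!("avx2 detected: the AVX2::argminmax benchmarks will run");
    } else {
        println!("no avx2: only the scalar / SSE benchmarks will run");
    }
}

#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
fn main() {}
```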
diff --git a/src/lib.rs b/src/lib.rs
index ee78409..6c9c170 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -125,10 +125,11 @@ macro_rules! impl_argminmax {
                 return unsafe { AVX512::argminmax(self) }
             } else if is_x86_feature_detected!("avx2") {
                 return unsafe { AVX2::argminmax(self) }
-            } else if is_x86_feature_detected!("avx") & (<$t>::NB_BITS >= 32) & (<$t>::IS_FLOAT == true) {
+            // TODO: current NaN handling requires avx2
+            } else if is_x86_feature_detected!("avx") & (<$t>::NB_BITS >= 64) & (<$t>::IS_FLOAT == true) {
                 // f32 and f64 do not require avx2
                 return unsafe { AVX2::argminmax(self) }
-            // SKIP SSE4.2 bc scalar is faster or equivalent for 64 bit numbers
+            // // SKIP SSE4.2 bc scalar is faster or equivalent for 64 bit numbers
             // // } else if is_x86_feature_detected!("sse4.2") & (<$t>::NB_BITS == 64) & (<$t>::IS_FLOAT == false) {
             // //     SSE4.2 is needed for comparing 64-bit integers
             //     return unsafe { SSE::argminmax(self) }
diff --git a/src/scalar/generic.rs b/src/scalar/generic.rs
index f16963c..134b335 100644
--- a/src/scalar/generic.rs
+++ b/src/scalar/generic.rs
@@ -22,7 +22,15 @@ pub fn scalar_argminmax<T: Copy + PartialOrd>(arr: &[T]) -> (usize, usize) {
     let mut high: T = unsafe { *arr.get_unchecked(high_index) };
     for i in 0..arr.len() {
         let v: T = unsafe { *arr.get_unchecked(i) };
-        if v < low {
+        if v != v {
+            // TODO: optimize this
+            // Handle NaNs: if the value is NaN, then return the index of that value
+            return (i, i);
+            // low = v;
+            // low_index = i;
+            // high = v;
+            // high_index = i;
+        } else if v < low {
             low = v;
             low_index = i;
         } else if v > high {
diff --git a/src/simd/generic.rs b/src/simd/generic.rs
index c5d708e..f64d2aa 100644
--- a/src/simd/generic.rs
+++ b/src/simd/generic.rs
@@ -122,11 +122,11 @@ pub trait SIMD<
             // Self::_mm_prefetch(arr.as_ptr().add(start));
             let (min_index_, min_value_, max_index_, max_value_) =
                 Self::_core_argminmax(&arr[start..start + dtype_max]);
-            if min_value_ < min_value {
+            if min_value_ < min_value || min_value_ != min_value_ {
                 min_index = start + min_index_;
                 min_value = min_value_;
             }
-            if max_value_ > max_value {
+            if max_value_ > max_value || max_value_ != max_value_ {
                 max_index = start + max_index_;
                 max_value = max_value_;
             }
@@ -137,11 +137,11 @@ pub trait SIMD<
             // Self::_mm_prefetch(arr.as_ptr().add(start));
             let (min_index_, min_value_, max_index_, max_value_) =
                 Self::_core_argminmax(&arr[start..]);
-            if min_value_ < min_value {
+            if min_value_ < min_value || min_value_ != min_value_ {
                 min_index = start + min_index_;
                 min_value = min_value_;
             }
-            if max_value_ > max_value {
+            if max_value_ > max_value || max_value_ != max_value_ {
                 max_index = start + max_index_;
                 max_value = max_value_;
             }
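Both the scalar fallback and the generic SIMD accumulation above detect NaN through self-comparison: NaN is the only value for which `v != v` holds, so the check works for any `T: PartialOrd` without a float-specific `is_nan()`. A standalone illustration:

```rust
// NaN is the only value that compares unequal to itself, so a generic
// PartialOrd bound is enough to detect it (integers are never "NaN").
fn is_nan_like<T: PartialOrd>(v: T) -> bool {
    v != v
}

fn main() {
    assert!(is_nan_like(f32::NAN));
    assert!(!is_nan_like(1.0f32));
    assert!(!is_nan_like(42i32));
}
```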
diff --git a/src/simd/simd_f32.rs b/src/simd/simd_f32.rs
index c56b3ad..6e34902 100644
--- a/src/simd/simd_f32.rs
+++ b/src/simd/simd_f32.rs
@@ -9,6 +9,18 @@ use std::arch::x86::*;
 #[cfg(target_arch = "x86_64")]
 use std::arch::x86_64::*;
 
+use super::task::{max_index_value, min_index_value};
+
+const XOR_VALUE: i32 = 0x7FFFFFFF;
+const BIT_SHIFT: i32 = 31;
+
+#[inline(always)]
+fn _ord_i32_to_f32(ord_i32: i32) -> f32 {
+    // TODO: more efficient transformation -> can be decreasing order as well
+    let v = ((ord_i32 >> BIT_SHIFT) & XOR_VALUE) ^ ord_i32;
+    unsafe { std::mem::transmute::<i32, f32>(v) }
+}
+
 // ------------------------------------------ AVX2 ------------------------------------------
 
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
@@ -17,57 +29,95 @@ mod avx2 {
     use super::*;
 
     const LANE_SIZE: usize = AVX2::LANE_SIZE_32;
+    const XOR_MASK: __m256i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) };
+
+    #[inline(always)]
+    unsafe fn _f32_as_m256i_to_i32ord(f32_as_m256i: __m256i) -> __m256i {
+        // on a scalar: ((v >> 31) & 0x7FFFFFFF) ^ v
+        let sign_bit_shifted = _mm256_srai_epi32(f32_as_m256i, BIT_SHIFT);
+        let sign_bit_masked = _mm256_and_si256(sign_bit_shifted, XOR_MASK);
+        _mm256_xor_si256(sign_bit_masked, f32_as_m256i)
+    }
 
-    impl SIMD<f32, __m256, __m256, LANE_SIZE> for AVX2 {
-        const INITIAL_INDEX: __m256 = unsafe {
-            std::mem::transmute([
-                0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32,
-            ])
-        };
-        // https://stackoverflow.com/a/3793950
-        const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS;
+    #[inline(always)]
+    unsafe fn _reg_to_i32_arr(reg: __m256i) -> [i32; LANE_SIZE] {
+        std::mem::transmute::<__m256i, [i32; LANE_SIZE]>(reg)
+    }
+
+    impl SIMD<f32, __m256i, __m256i, LANE_SIZE> for AVX2 {
+        const INITIAL_INDEX: __m256i =
+            unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32]) };
+        const MAX_INDEX: usize = i32::MAX as usize;
 
         #[inline(always)]
-        unsafe fn _reg_to_arr(reg: __m256) -> [f32; LANE_SIZE] {
-            std::mem::transmute::<__m256, [f32; LANE_SIZE]>(reg)
+        unsafe fn _reg_to_arr(_: __m256i) -> [f32; LANE_SIZE] {
+            unimplemented!()
         }
 
         #[inline(always)]
-        unsafe fn _mm_loadu(data: *const f32) -> __m256 {
-            _mm256_loadu_ps(data as *const f32)
+        unsafe fn _mm_loadu(data: *const f32) -> __m256i {
+            _f32_as_m256i_to_i32ord(_mm256_loadu_si256(data as *const __m256i))
        }
 
         #[inline(always)]
-        unsafe fn _mm_set1(a: usize) -> __m256 {
-            _mm256_set1_ps(a as f32)
+        unsafe fn _mm_set1(a: usize) -> __m256i {
+            _mm256_set1_epi32(a as i32)
         }
 
         #[inline(always)]
-        unsafe fn _mm_add(a: __m256, b: __m256) -> __m256 {
-            _mm256_add_ps(a, b)
+        unsafe fn _mm_add(a: __m256i, b: __m256i) -> __m256i {
+            _mm256_add_epi32(a, b)
         }
 
         #[inline(always)]
-        unsafe fn _mm_cmpgt(a: __m256, b: __m256) -> __m256 {
-            _mm256_cmp_ps(a, b, _CMP_GT_OQ)
+        unsafe fn _mm_cmpgt(a: __m256i, b: __m256i) -> __m256i {
+            _mm256_cmpgt_epi32(a, b)
         }
 
         #[inline(always)]
-        unsafe fn _mm_cmplt(a: __m256, b: __m256) -> __m256 {
-            _mm256_cmp_ps(b, a, _CMP_GT_OQ)
+        unsafe fn _mm_cmplt(a: __m256i, b: __m256i) -> __m256i {
+            _mm256_cmpgt_epi32(b, a)
         }
 
         #[inline(always)]
-        unsafe fn _mm_blendv(a: __m256, b: __m256, mask: __m256) -> __m256 {
-            _mm256_blendv_ps(a, b, mask)
+        unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i {
+            _mm256_blendv_epi8(a, b, mask)
         }
 
         // ------------------------------------ ARGMINMAX --------------------------------------
 
-        #[target_feature(enable = "avx")]
+        #[target_feature(enable = "avx2")]
         unsafe fn argminmax(data: &[f32]) -> (usize, usize) {
             Self::_argminmax(data)
         }
+
+        #[inline(always)]
+        unsafe fn _get_min_max_index_value(
+            index_low: __m256i,
+            values_low: __m256i,
+            index_high: __m256i,
+            values_high: __m256i,
+        ) -> (usize, f32, usize, f32) {
+            // Get the results as arrays
+            let index_low_arr = _reg_to_i32_arr(index_low);
+            let values_low_arr = _reg_to_i32_arr(values_low);
+            let index_high_arr = _reg_to_i32_arr(index_high);
+            let values_high_arr = _reg_to_i32_arr(values_high);
+            // Find the min and max values and their indices
+            let (min_index, min_value) = min_index_value(&index_low_arr, &values_low_arr);
+            let (max_index, max_value) = max_index_value(&index_high_arr, &values_high_arr);
+            // Return the results - convert the ordinal ints back to floats
+            let min_value = _ord_i32_to_f32(min_value);
+            let max_value = _ord_i32_to_f32(max_value);
+            if min_value != min_value && max_value == max_value {
+                // min_value is the only NaN
+                return (min_index as usize, min_value, min_index as usize, min_value);
+            } else if min_value == min_value && max_value != max_value {
+                // max_value is the only NaN
+                return (max_index as usize, max_value, max_index as usize, max_value);
+            }
+            (min_index as usize, min_value, max_index as usize, max_value)
+        }
     }
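The AVX2 implementation above loads the f32 data as raw bits and applies `((v >> 31) & 0x7FFFFFFF) ^ v` per lane, which maps IEEE-754 floats to `i32` values whose signed order matches the float order, while pushing positive NaNs above +inf and negative NaNs below -inf; that is what lets `_get_min_max_index_value` recognise a NaN after converting back. A small scalar sketch of the round trip (separate from the SIMD code):

```rust
const XOR_VALUE: i32 = 0x7FFFFFFF;
const BIT_SHIFT: i32 = 31;

// f32 bits -> order-preserving i32 (the same transform the SIMD load applies per lane)
fn f32_to_ord_i32(f: f32) -> i32 {
    let v = f.to_bits() as i32;
    ((v >> BIT_SHIFT) & XOR_VALUE) ^ v
}

// The transform is its own inverse on the bit pattern.
fn ord_i32_to_f32(ord: i32) -> f32 {
    let v = ((ord >> BIT_SHIFT) & XOR_VALUE) ^ ord;
    f32::from_bits(v as u32)
}

fn main() {
    let xs = [-3.5f32, -0.0, 0.0, 1.25, f32::INFINITY];
    let ords: Vec<i32> = xs.iter().map(|&x| f32_to_ord_i32(x)).collect();
    // Integer order follows float order for non-NaN inputs.
    assert!(ords.windows(2).all(|w| w[0] <= w[1]));
    // A (positive) NaN maps beyond +inf, so it surfaces as the lane max
    // and can be detected after converting back.
    assert!(f32_to_ord_i32(f32::NAN) > f32_to_ord_i32(f32::INFINITY));
    assert!(ord_i32_to_f32(f32_to_ord_i32(1.25)) == 1.25);
}
```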
 
     // ------------------------------------ TESTS --------------------------------------
 
@@ -86,7 +136,7 @@
     #[test]
     fn test_both_versions_return_the_same_results() {
-        if !is_x86_feature_detected!("avx") {
+        if !is_x86_feature_detected!("avx2") {
             return;
         }
 
@@ -101,7 +151,7 @@
     #[test]
     fn test_first_index_is_returned_when_identical_values_found() {
-        if !is_x86_feature_detected!("avx") {
+        if !is_x86_feature_detected!("avx2") {
             return;
         }
 
@@ -126,9 +176,52 @@
         assert_eq!(argmax_simd_index, 1);
     }
 
+    #[test]
+    fn test_return_nan_index() {
+        if !is_x86_feature_detected!("avx2") {
+            return;
+        }
+
+        // Case 1: NaN is the first element
+        let mut data: Vec<f32> = get_array_f32(1027);
+        data[0] = std::f32::NAN;
+
+        let (argmin_index, argmax_index) = scalar_argminmax(&data);
+        assert_eq!(argmin_index, 0);
+        assert_eq!(argmax_index, 0);
+
+        let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) };
+        assert_eq!(argmin_simd_index, 0);
+        assert_eq!(argmax_simd_index, 0);
+
+        // Case 2: NaN is the last element
+        let mut data: Vec<f32> = get_array_f32(1027);
+        data[1026] = std::f32::NAN;
+
+        let (argmin_index, argmax_index) = scalar_argminmax(&data);
+        assert_eq!(argmin_index, 1026);
+        assert_eq!(argmax_index, 1026);
+
+        let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) };
+        assert_eq!(argmin_simd_index, 1026);
+        assert_eq!(argmax_simd_index, 1026);
+
+        // Case 3: NaN is somewhere in the middle
+        let mut data: Vec<f32> = get_array_f32(1027);
+        data[123] = std::f32::NAN;
+
+        let (argmin_index, argmax_index) = scalar_argminmax(&data);
+        assert_eq!(argmin_index, 123);
+        assert_eq!(argmax_index, 123);
+
+        let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) };
+        assert_eq!(argmin_simd_index, 123);
+        assert_eq!(argmax_simd_index, 123);
+    }
+
     #[test]
     fn test_no_overflow() {
-        if !is_x86_feature_detected!("avx") {
+        if !is_x86_feature_detected!("avx2") {
             return;
         }
 
@@ -143,7 +236,7 @@
     #[test]
     fn test_many_random_runs() {
-        if !is_x86_feature_detected!("avx") {
+        if !is_x86_feature_detected!("avx2") {
             return;
         }
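The `_get_min_max_index_value` implementations (in the AVX2 block above and the SSE block below) delegate the horizontal reduction to `min_index_value` / `max_index_value` from `super::task`, which this diff imports but does not show. A plausible shape for such a helper, purely as an assumption about what it does (not the crate's actual code):

```rust
// Hypothetical sketch of a horizontal reduction over the per-lane index/value
// arrays: pick the smallest value, preferring the lowest index on ties
// (consistent with the "first index is returned" tests above).
fn min_index_value(indices: &[i32], values: &[i32]) -> (i32, i32) {
    let mut best = 0;
    for i in 1..values.len() {
        if values[i] < values[best] || (values[i] == values[best] && indices[i] < indices[best]) {
            best = i;
        }
    }
    (indices[best], values[best])
}
```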
@@ -166,46 +259,58 @@ mod sse {
     use super::*;
 
     const LANE_SIZE: usize = SSE::LANE_SIZE_32;
+    const XOR_MASK: __m128i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) };
+
+    #[inline(always)]
+    unsafe fn _f32_as_m128i_to_i32ord(f32_as_m128i: __m128i) -> __m128i {
+        // on a scalar: ((v >> 31) & 0x7FFFFFFF) ^ v
+        let sign_bit_shifted = _mm_srai_epi32(f32_as_m128i, BIT_SHIFT);
+        let sign_bit_masked = _mm_and_si128(sign_bit_shifted, XOR_MASK);
+        _mm_xor_si128(sign_bit_masked, f32_as_m128i)
+    }
 
-    impl SIMD<f32, __m128, __m128, LANE_SIZE> for SSE {
-        const INITIAL_INDEX: __m128 =
-            unsafe { std::mem::transmute([0.0f32, 1.0f32, 2.0f32, 3.0f32]) };
-        // https://stackoverflow.com/a/3793950
-        const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS;
+    #[inline(always)]
+    unsafe fn _reg_to_i32_arr(reg: __m128i) -> [i32; LANE_SIZE] {
+        std::mem::transmute::<__m128i, [i32; LANE_SIZE]>(reg)
+    }
+
+    impl SIMD<f32, __m128i, __m128i, LANE_SIZE> for SSE {
+        const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) };
+        const MAX_INDEX: usize = i32::MAX as usize;
 
         #[inline(always)]
-        unsafe fn _reg_to_arr(reg: __m128) -> [f32; LANE_SIZE] {
-            std::mem::transmute::<__m128, [f32; LANE_SIZE]>(reg)
+        unsafe fn _reg_to_arr(_: __m128i) -> [f32; LANE_SIZE] {
+            unimplemented!()
         }
 
         #[inline(always)]
-        unsafe fn _mm_loadu(data: *const f32) -> __m128 {
-            _mm_loadu_ps(data as *const f32)
+        unsafe fn _mm_loadu(data: *const f32) -> __m128i {
+            _f32_as_m128i_to_i32ord(_mm_loadu_si128(data as *const __m128i))
        }
 
         #[inline(always)]
-        unsafe fn _mm_set1(a: usize) -> __m128 {
-            _mm_set1_ps(a as f32)
+        unsafe fn _mm_set1(a: usize) -> __m128i {
+            _mm_set1_epi32(a as i32)
         }
 
         #[inline(always)]
-        unsafe fn _mm_add(a: __m128, b: __m128) -> __m128 {
-            _mm_add_ps(a, b)
+        unsafe fn _mm_add(a: __m128i, b: __m128i) -> __m128i {
+            _mm_add_epi32(a, b)
         }
 
         #[inline(always)]
-        unsafe fn _mm_cmpgt(a: __m128, b: __m128) -> __m128 {
-            _mm_cmpgt_ps(a, b)
+        unsafe fn _mm_cmpgt(a: __m128i, b: __m128i) -> __m128i {
+            _mm_cmpgt_epi32(a, b)
         }
 
         #[inline(always)]
-        unsafe fn _mm_cmplt(a: __m128, b: __m128) -> __m128 {
-            _mm_cmplt_ps(a, b)
+        unsafe fn _mm_cmplt(a: __m128i, b: __m128i) -> __m128i {
+            _mm_cmplt_epi32(a, b)
        }
 
         #[inline(always)]
-        unsafe fn _mm_blendv(a: __m128, b: __m128, mask: __m128) -> __m128 {
-            _mm_blendv_ps(a, b, mask)
+        unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i {
+            _mm_blendv_epi8(a, b, mask)
         }
 
         // ------------------------------------ ARGMINMAX --------------------------------------
 
@@ -214,6 +319,34 @@
         unsafe fn argminmax(data: &[f32]) -> (usize, usize) {
             Self::_argminmax(data)
         }
+
+        #[inline(always)]
+        unsafe fn _get_min_max_index_value(
+            index_low: __m128i,
+            values_low: __m128i,
+            index_high: __m128i,
+            values_high: __m128i,
+        ) -> (usize, f32, usize, f32) {
+            // Get the results as arrays
+            let index_low_arr = _reg_to_i32_arr(index_low);
+            let values_low_arr = _reg_to_i32_arr(values_low);
+            let index_high_arr = _reg_to_i32_arr(index_high);
+            let values_high_arr = _reg_to_i32_arr(values_high);
+            // Find the min and max values and their indices
+            let (min_index, min_value) = min_index_value(&index_low_arr, &values_low_arr);
+            let (max_index, max_value) = max_index_value(&index_high_arr, &values_high_arr);
+            // Return the results - convert the ordinal ints back to floats
+            let min_value = _ord_i32_to_f32(min_value);
+            let max_value = _ord_i32_to_f32(max_value);
+            if min_value != min_value && max_value == max_value {
+                // min_value is the only NaN
+                return (min_index as usize, min_value, min_index as usize, min_value);
+            } else if min_value == min_value && max_value != max_value {
+                // max_value is the only NaN
+                return (max_index as usize, max_value, max_index as usize, max_value);
+            }
+            (min_index as usize, min_value, max_index as usize, max_value)
+        }
     }
 
     // ------------------------------------ TESTS --------------------------------------
 
@@ -264,6 +397,45 @@
         assert_eq!(argmax_simd_index, 1);
     }
 
+    #[test]
+    fn test_return_nan_index() {
+        // Case 1: NaN is the first element
+        let mut data: Vec<f32> = get_array_f32(1027);
+        data[0] = std::f32::NAN;
+
+        let (argmin_index, argmax_index) = scalar_argminmax(&data);
+        assert_eq!(argmin_index, 0);
+        assert_eq!(argmax_index, 0);
+
+        let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) };
+        assert_eq!(argmin_simd_index, 0);
+        assert_eq!(argmax_simd_index, 0);
+
+        // Case 2: NaN is the last element
+        let mut data: Vec<f32> = get_array_f32(1027);
+        data[1026] = std::f32::NAN;
+
+        let (argmin_index, argmax_index) = scalar_argminmax(&data);
+        assert_eq!(argmin_index, 1026);
+        assert_eq!(argmax_index, 1026);
+
+        let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) };
+        assert_eq!(argmin_simd_index, 1026);
+        assert_eq!(argmax_simd_index, 1026);
+
+        // Case 3: NaN is somewhere in the middle
+        let mut data: Vec<f32> = get_array_f32(1027);
+        data[123] = std::f32::NAN;
+
+        let (argmin_index, argmax_index) = scalar_argminmax(&data);
+        assert_eq!(argmin_index, 123);
+        assert_eq!(argmax_index, 123);
+
+        let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) };
+        assert_eq!(argmin_simd_index, 123);
+        assert_eq!(argmax_simd_index, 123);
+    }
+
     #[test]
     fn test_no_overflow() {
         let n: usize = 1 << 25;
diff --git a/src/simd/task.rs b/src/simd/task.rs
index 7ccc0b8..51b6de9 100644
--- a/src/simd/task.rs
+++ b/src/simd/task.rs
@@ -14,6 +14,9 @@ where
     assert!(!arr.is_empty()); // split_array should never return (None, None)
     match split_array(arr, lane_size) {
         (Some(sim), Some(rem)) => {
+            // Perform the SIMD operation on the first part of the array
+            let sim_result = unsafe { core_argminmax(sim) };
+            // Perform the scalar operation on the remainder of the array
             let (rem_min_index, rem_max_index) = SCALAR::argminmax(rem);
             let rem_result = (
                 rem_min_index + sim.len(),
                 rem[rem_min_index],
@@ -21,8 +24,8 @@
                 rem_max_index + sim.len(),
                 rem[rem_max_index],
             );
-            let sim_result = unsafe { core_argminmax(sim) };
-            find_final_index_minmax(rem_result, sim_result)
+            // Select the final result
+            find_final_index_minmax(sim_result, rem_result)
         }
         (None, Some(rem)) => {
             let (rem_min_index, rem_max_index) = SCALAR::argminmax(rem);
@@ -55,21 +58,45 @@ fn split_array(arr: &[T], lane_size: usize) -> (Option<&[T]>, Option<&[
     }
 }
 
+// TODO: if we support argmin & argmax (as two separate funcs), this func should be
+// broken up into two separate functions.
 #[inline(always)]
 fn find_final_index_minmax<T: Copy + PartialOrd>(
     remainder_result: (usize, T, usize, T),
     simd_result: (usize, T, usize, T),
 ) -> (usize, usize) {
-    let min_result = match remainder_result.1.partial_cmp(&simd_result.1).unwrap() {
-        Ordering::Less => remainder_result.0,
-        Ordering::Equal => std::cmp::min(remainder_result.0, simd_result.0),
-        Ordering::Greater => simd_result.0,
+    let min_result = match simd_result.1.partial_cmp(&remainder_result.1) {
+        Some(Ordering::Less) => simd_result.0,
+        Some(Ordering::Equal) => std::cmp::min(simd_result.0, remainder_result.0),
+        Some(Ordering::Greater) => remainder_result.0,
+        // Handle NaNs: if a value is NaN, then return the index of that value
+        // This should prefer the simd result over the remainder result if both are NaN
+        None => {
+            // Instead of checking for NaN, we check if the value is not equal to itself
+            // This is because NaN != NaN
+            if simd_result.1 != simd_result.1 {
+                simd_result.0
+            } else {
+                remainder_result.0
+            }
+        }
     };
-    let max_result = match simd_result.3.partial_cmp(&remainder_result.3).unwrap() {
-        Ordering::Less => remainder_result.2,
-        Ordering::Equal => std::cmp::min(remainder_result.2, simd_result.2),
-        Ordering::Greater => simd_result.2,
+    let max_result = match simd_result.3.partial_cmp(&remainder_result.3) {
+        Some(Ordering::Less) => remainder_result.2,
+        Some(Ordering::Equal) => std::cmp::min(remainder_result.2, simd_result.2),
+        Some(Ordering::Greater) => simd_result.2,
+        // Handle NaNs: if a value is NaN, then return the index of that value
+        // This should prefer the simd result over the remainder result if both are NaN
+        None => {
+            // Instead of checking for NaN, we check if the value is not equal to itself
+            // This is because NaN != NaN
+            if simd_result.3 != simd_result.3 {
+                simd_result.2
+            } else {
+                remainder_result.2
+            }
+        }
     };
     (min_result, max_result)
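`find_final_index_minmax` now matches on `partial_cmp` directly: `None` means at least one side is NaN, and the SIMD result is preferred when both are. A standalone sketch of that merge rule for the minimum (names are local to the example, not the crate's API):

```rust
use std::cmp::Ordering;

// Merge two (index, value) candidates for the minimum, mirroring the NaN rule
// above: a NaN value wins, and the SIMD candidate is preferred when both are NaN.
fn merge_min<T: Copy + PartialOrd>(simd: (usize, T), rem: (usize, T)) -> usize {
    match simd.1.partial_cmp(&rem.1) {
        Some(Ordering::Less) => simd.0,
        Some(Ordering::Equal) => simd.0.min(rem.0),
        Some(Ordering::Greater) => rem.0,
        // partial_cmp returns None only when at least one value is NaN
        None => {
            if simd.1 != simd.1 {
                simd.0
            } else {
                rem.0
            }
        }
    }
}

fn main() {
    assert_eq!(merge_min((8, 1.0f32), (130, -2.0)), 130); // plain comparison
    assert_eq!(merge_min((8, f32::NAN), (130, -2.0)), 8); // NaN wins
    assert_eq!(merge_min((8, 1.0f32), (130, f32::NAN)), 130);
}
```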