Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚧 POC - support NaNs for SSE & AVX2 f32 #18

Closed
wants to merge 12 commits into from
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ codspeed-criterion-compat = "1.0.1"
criterion = "0.3.1"
dev_utils = { path = "dev_utils" }

[profile.release]
lto = true


[[bench]]
name = "bench_f16"
Expand Down
8 changes: 4 additions & 4 deletions benches/bench_f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ fn minmax_f32_random_array_long(c: &mut Criterion) {
});
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx") {
if is_x86_feature_detected!("avx2") {
c.bench_function("avx_random_long_f32", |b| {
b.iter(|| unsafe { AVX2::argminmax(black_box(data)) })
});
Expand Down Expand Up @@ -67,7 +67,7 @@ fn minmax_f32_random_array_short(c: &mut Criterion) {
});
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx") {
if is_x86_feature_detected!("avx2") {
c.bench_function("avx_random_short_f32", |b| {
b.iter(|| unsafe { AVX2::argminmax(black_box(data)) })
});
Expand Down Expand Up @@ -108,7 +108,7 @@ fn minmax_f32_worst_case_array_long(c: &mut Criterion) {
});
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx") {
if is_x86_feature_detected!("avx2") {
c.bench_function("avx_worst_long_f32", |b| {
b.iter(|| unsafe { AVX2::argminmax(black_box(data)) })
});
Expand Down Expand Up @@ -149,7 +149,7 @@ fn minmax_f32_worst_case_array_short(c: &mut Criterion) {
});
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx") {
if is_x86_feature_detected!("avx2") {
c.bench_function("avx_worst_short_f32", |b| {
b.iter(|| unsafe { AVX2::argminmax(black_box(data)) })
});
Expand Down
5 changes: 3 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,11 @@ macro_rules! impl_argminmax {
return unsafe { AVX512::argminmax(self) }
} else if is_x86_feature_detected!("avx2") {
return unsafe { AVX2::argminmax(self) }
} else if is_x86_feature_detected!("avx") & (<$t>::NB_BITS >= 32) & (<$t>::IS_FLOAT == true) {
// TODO: current NaN handling requires avx2
} else if is_x86_feature_detected!("avx") & (<$t>::NB_BITS >= 64) & (<$t>::IS_FLOAT == true) {
// f32 and f64 do not require avx2
return unsafe { AVX2::argminmax(self) }
// SKIP SSE4.2 bc scalar is faster or equivalent for 64 bit numbers
// // SKIP SSE4.2 bc scalar is faster or equivalent for 64 bit numbers
// // } else if is_x86_feature_detected!("sse4.2") & (<$t>::NB_BITS == 64) & (<$t>::IS_FLOAT == false) {
// // SSE4.2 is needed for comparing 64-bit integers
// return unsafe { SSE::argminmax(self) }
Expand Down
10 changes: 9 additions & 1 deletion src/scalar/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,15 @@ pub fn scalar_argminmax<T: Copy + PartialOrd>(arr: &[T]) -> (usize, usize) {
let mut high: T = unsafe { *arr.get_unchecked(high_index) };
for i in 0..arr.len() {
let v: T = unsafe { *arr.get_unchecked(i) };
if v < low {
if v != v {
// TODO: optimize this
// Handle NaNs: if value is NaN, than return index of that value
return (i, i);
// low = v;
// low_index = i;
// high = v;
// high_index = i;
} else if v < low {
low = v;
low_index = i;
} else if v > high {
Expand Down
8 changes: 4 additions & 4 deletions src/simd/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,11 @@ pub trait SIMD<
// Self::_mm_prefetch(arr.as_ptr().add(start));
let (min_index_, min_value_, max_index_, max_value_) =
Self::_core_argminmax(&arr[start..start + dtype_max]);
if min_value_ < min_value {
if min_value_ < min_value || min_value_ != min_value_ {
min_index = start + min_index_;
min_value = min_value_;
}
if max_value_ > max_value {
if max_value_ > max_value || max_value_ != max_value_ {
max_index = start + max_index_;
max_value = max_value_;
}
Expand All @@ -137,11 +137,11 @@ pub trait SIMD<
// Self::_mm_prefetch(arr.as_ptr().add(start));
let (min_index_, min_value_, max_index_, max_value_) =
Self::_core_argminmax(&arr[start..]);
if min_value_ < min_value {
if min_value_ < min_value || min_value_ != min_value_ {
min_index = start + min_index_;
min_value = min_value_;
}
if max_value_ > max_value {
if max_value_ > max_value || max_value_ != max_value_ {
max_index = start + max_index_;
max_value = max_value_;
}
Expand Down
Loading