Skip to content

Commit

Permalink
♻️ further decoupling of traits
Browse files Browse the repository at this point in the history
  • Loading branch information
jvdd committed Feb 11, 2023
1 parent ddad0a5 commit 07a5e66
Show file tree
Hide file tree
Showing 14 changed files with 198 additions and 304 deletions.
12 changes: 12 additions & 0 deletions src/simd/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ pub trait SIMDInstructionSet {

// ----------------------------- x86_64 / x86 -----------------------------

// SIMDInstructionSet only implemented for SSE
// -> necessary to auto implement SIMDCore & SIMDArgMinMax in generic.rs

pub struct SSE;
pub struct SSEFloatIgnoreNaN;
pub struct SSEFloatReturnNaN;
Expand All @@ -24,6 +27,9 @@ impl SIMDInstructionSet for SSE {
const REGISTER_SIZE: usize = 128;
}

// SIMDInstructionSet only implemented for AVX2
// -> necessary to auto implement SIMDCore & SIMDArgMinMax in generic.rs

pub struct AVX2; // for f32 and f64 AVX is enough
pub struct AVX2FloatIgnoreNaN;
pub struct AVX2FloatReturnNaN;
Expand All @@ -32,6 +38,9 @@ impl SIMDInstructionSet for AVX2 {
const REGISTER_SIZE: usize = 256;
}

// SIMDInstructionSet only implemented for AVX512
// -> necessary to auto implement SIMDCore & SIMDArgMinMax in generic.rs

pub struct AVX512;
pub struct AVX512FloatIgnoreNaN;
pub struct AVX512FloatReturnNaN;
Expand All @@ -42,6 +51,9 @@ impl SIMDInstructionSet for AVX512 {

// ----------------------------- aarch64 / arm -----------------------------

// SIMDInstructionSet only implemented for NEON
// -> necessary to auto implement SIMDCore & SIMDArgMinMax in generic.rs

pub struct NEON;
pub struct NEONFloatIgnoreNaN;
pub struct NEONFloatReturnNaN;
Expand Down
76 changes: 50 additions & 26 deletions src/simd/generic.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use num_traits::{AsPrimitive, Float};
use num_traits::{AsPrimitive, Float, PrimInt};

use super::config::SIMDInstructionSet;
use super::task::*;
use crate::scalar::{ScalarArgMinMax, SCALAR};

Expand All @@ -12,6 +13,9 @@ where
SIMDVecDtype: Copy,
SIMDMaskDtype: Copy,
{
/// The target feature for the SIMD operations
const TARGET_FEATURE: &'static str;

/// Integers > this value **cannot** be accurately represented in SIMDVecDtype
const MAX_INDEX: usize;
/// Initial index value for the SIMD vector
Expand Down Expand Up @@ -215,33 +219,33 @@ where
}
}

// Implement SIMDCore where SIMDOps is implemented
// Implement SIMDCore where SIMDOps is implemented for signed and unsigned integers (PrimInt)
impl<T, ScalarDType, SIMDVecDtype, SIMDMaskDtype, const LANE_SIZE: usize>
SIMDCore<ScalarDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE> for T
where
ScalarDType: Copy + PartialOrd + AsPrimitive<usize>,
ScalarDType: PrimInt + AsPrimitive<usize>,
SIMDVecDtype: Copy,
SIMDMaskDtype: Copy,
T: SIMDOps<ScalarDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE>,
T: SIMDOps<ScalarDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE> + SIMDInstructionSet,
{
// Use the default implementation
// Use the implementation
}

// --------------- Float Ignore NaNs

/// SIMD operations for setting a SIMD vector to a scalar value.
/// SIMD operations for setting a SIMD vector to a scalar value (only required for floats)
pub trait SIMDSetOps<ScalarDType, SIMDVecDtype>
where
ScalarDType: Copy,
ScalarDType: Float,
{
unsafe fn _mm_set1(a: ScalarDType) -> SIMDVecDtype;
}

/// SIMDCore trait that ignore NaNs (for float types)
/// SIMDCore trait that ignores NaNs (for floats)
pub trait SIMDCoreFloatIgnoreNaN<ScalarDType, SIMDVecDtype, SIMDMaskDtype, const LANE_SIZE: usize>:
SIMDOps<ScalarDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE> + SIMDSetOps<ScalarDType, SIMDVecDtype>
where
ScalarDType: Copy + PartialOrd + AsPrimitive<usize> + Float,
ScalarDType: Float + AsPrimitive<usize>,
SIMDVecDtype: Copy,
SIMDMaskDtype: Copy,
{
Expand Down Expand Up @@ -346,22 +350,22 @@ where
}
}

// Implement SIMDCoreFloatIgnoreNaNs where SIMDOps + SIMDSetOps is implemented
// Implement SIMDCoreFloatIgnoreNaN where SIMDOps + SIMDSetOps is implemented for floats (Float)
impl<T, SCALARDType, SIMDVecDtype, SIMDMaskDtype, const LANE_SIZE: usize>
SIMDCoreFloatIgnoreNaN<SCALARDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE> for T
where
SCALARDType: Copy + PartialOrd + AsPrimitive<usize> + Float,
SCALARDType: Float + AsPrimitive<usize>,
SIMDVecDtype: Copy,
SIMDMaskDtype: Copy,
T: SIMDOps<SCALARDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE>
+ SIMDSetOps<SCALARDType, SIMDVecDtype>,
{
// Use the default implementation
// Use the implementation
}

// --------------- Float Return NaNs

// IDEA: make SIMDOps extend this trait & provide empty default implementations
// IDEA: make SIMDOps extend this trait & provide empty implementations

// pub trait SIMDOrdTransformOps<FloatDType, IntDType, SIMDVecDTtype, const LANE_SIZE: usize>
// where
Expand All @@ -385,6 +389,10 @@ where

// ------------------------------- ArgMinMax SIMD TRAIT ------------------------------

// --------------- Default

/// The default SIMDArgMinMax trait (can be implemented for signed and unsigned integers + floats)
/// TODO: decide if we create a separate trait for floats when returning NaNs
#[allow(clippy::missing_safety_doc)] // TODO: add safety docs?
pub trait SIMDArgMinMax<ScalarDType, SIMDVecDtype, SIMDMaskDtype, const LANE_SIZE: usize>:
SIMDCore<ScalarDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE>
Expand All @@ -394,44 +402,60 @@ where
SIMDMaskDtype: Copy,
{
/// Returns the index of the minimum and maximum value in the array
unsafe fn argminmax(data: &[ScalarDType]) -> (usize, usize);

// Is necessary to have a separate function for this so we can call it in the
// argminmax function when we add the target feature to the function.
#[inline(always)]
unsafe fn _argminmax(data: &[ScalarDType]) -> (usize, usize)
#[target_feature(enable = Self::TARGET_FEATURE)]
unsafe fn argminmax(data: &[ScalarDType]) -> (usize, usize)
where
SCALAR: ScalarArgMinMax<ScalarDType>,
{
argminmax_generic(data, LANE_SIZE, Self::_overflow_safe_core_argminmax)
}
}

// Implement SIMDArgMinMax where SIMDCore is implemented for signed and unsigned integers (PrimInt)
impl<T, SCALARDType, SIMDVecDtype, SIMDMaskDtype, const LANE_SIZE: usize>
SIMDArgMinMax<SCALARDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE> for T
where
SCALARDType: PrimInt + AsPrimitive<usize>,
SIMDVecDtype: Copy,
SIMDMaskDtype: Copy,
T: SIMDCore<SCALARDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE> + SIMDInstructionSet,
{
// Use the implementation
}

#[allow(clippy::missing_safety_doc)] // TODO: add safety docs?
pub trait SIMDArgMinMaxFloatIgnoreNaN<
ScalarDType,
SIMDVecDtype,
SIMDMaskDtype,
const LANE_SIZE: usize,
>: SIMDCoreFloatIgnoreNaN<ScalarDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE> where
ScalarDType: Copy + PartialOrd + AsPrimitive<usize> + Float,
ScalarDType: Float + AsPrimitive<usize>,
SIMDVecDtype: Copy,
SIMDMaskDtype: Copy,
{
/// Returns the index of the minimum and maximum value in the array
unsafe fn argminmax(data: &[ScalarDType]) -> (usize, usize);

// Is necessary to have a separate function for this so we can call it in the
// argminmax function when we add the target feature to the function.
#[inline(always)]
unsafe fn _argminmax(data: &[ScalarDType]) -> (usize, usize)
#[target_feature(enable = Self::TARGET_FEATURE)]
unsafe fn argminmax(data: &[ScalarDType]) -> (usize, usize)
where
SCALAR: ScalarArgMinMax<ScalarDType>,
{
argminmax_generic(data, LANE_SIZE, Self::_overflow_safe_core_argminmax)
}
}

// Implement SIMDArgMinMaxFloatIgnoreNaN where SIMDCoreFloatIgnoreNaN is implemented for floats (Float)
impl<T, SCALARDType, SIMDVecDtype, SIMDMaskDtype, const LANE_SIZE: usize>
SIMDArgMinMaxFloatIgnoreNaN<SCALARDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE> for T
where
SCALARDType: Float + AsPrimitive<usize>,
SIMDVecDtype: Copy,
SIMDMaskDtype: Copy,
T: SIMDCoreFloatIgnoreNaN<SCALARDType, SIMDVecDtype, SIMDMaskDtype, LANE_SIZE>,
{
// Use the implementation
}

// TODO: update this to use the new SIMD trait
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
macro_rules! unimplement_simd {
Expand Down
4 changes: 2 additions & 2 deletions src/simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ pub use generic::*;
mod simd_f16_ignore_nans;
// mod simd_f16_return_nans;
mod simd_f32_ignore_nans;
mod simd_f32_return_nans;
// mod simd_f32_return_nans;
mod simd_f64_ignore_nans;
mod simd_f64_return_nans;
// mod simd_f64_return_nans;
// SIGNED INT
mod simd_i16;
mod simd_i32;
Expand Down
35 changes: 10 additions & 25 deletions src/simd/simd_f16_ignore_nans.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#[cfg(feature = "half")]
use super::config::SIMDInstructionSet;
#[cfg(feature = "half")]
use super::generic::{SIMDArgMinMaxFloatIgnoreNaN, SIMDOps, SIMDSetOps};
use super::generic::{SIMDOps, SIMDSetOps};

#[cfg(feature = "half")]
#[cfg(target_arch = "aarch64")]
Expand Down Expand Up @@ -61,6 +61,8 @@ mod avx2 {
}

impl SIMDOps<f16, __m256i, __m256i, LANE_SIZE> for AVX2FloatIgnoreNaN {
const TARGET_FEATURE: &'static str = "avx2";

const INITIAL_INDEX: __m256i = unsafe {
std::mem::transmute([
0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16,
Expand Down Expand Up @@ -172,20 +174,13 @@ mod avx2 {
}
}

impl SIMDArgMinMaxFloatIgnoreNaN<f16, __m256i, __m256i, LANE_SIZE> for AVX2FloatIgnoreNaN {
#[target_feature(enable = "avx2")]
unsafe fn argminmax(data: &[f16]) -> (usize, usize) {
Self::_argminmax(data)
}
}

//----- TESTS -----

#[cfg(test)]
mod tests {
use super::AVX2FloatIgnoreNaN as AVX2;
use super::SIMDArgMinMaxFloatIgnoreNaN;
use crate::scalar::generic::scalar_argminmax;
use crate::simd::generic::SIMDArgMinMaxFloatIgnoreNaN;

use half::f16;

Expand Down Expand Up @@ -296,6 +291,8 @@ mod sse {
}

impl SIMDOps<f16, __m128i, __m128i, LANE_SIZE> for SSEFloatIgnoreNaN {
const TARGET_FEATURE: &'static str = "sse4.1";

const INITIAL_INDEX: __m128i =
unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) };
const INDEX_INCREMENT: __m128i =
Expand Down Expand Up @@ -399,20 +396,13 @@ mod sse {
}
}

impl SIMDArgMinMaxFloatIgnoreNaN<f16, __m128i, __m128i, LANE_SIZE> for SSEFloatIgnoreNaN {
#[target_feature(enable = "sse4.1")]
unsafe fn argminmax(data: &[f16]) -> (usize, usize) {
Self::_argminmax(data)
}
}

// ----------------------------------------- TESTS -----------------------------------------

#[cfg(test)]
mod tests {
use super::SIMDArgMinMaxFloatIgnoreNaN;
use super::SSEFloatIgnoreNaN as SSE;
use crate::scalar::generic::scalar_argminmax;
use crate::simd::generic::SIMDArgMinMaxFloatIgnoreNaN;

use half::f16;

Expand Down Expand Up @@ -507,6 +497,8 @@ mod avx512 {
}

impl SIMDOps<f16, __m512i, u32, LANE_SIZE> for AVX512FloatIgnoreNaN {
const TARGET_FEATURE: &'static str = "avx512bw";

const INITIAL_INDEX: __m512i = unsafe {
std::mem::transmute([
0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16,
Expand Down Expand Up @@ -623,20 +615,13 @@ mod avx512 {
}
}

impl SIMDArgMinMaxFloatIgnoreNaN<f16, __m512i, u32, LANE_SIZE> for AVX512FloatIgnoreNaN {
#[target_feature(enable = "avx512bw")]
unsafe fn argminmax(data: &[f16]) -> (usize, usize) {
Self::_argminmax(data)
}
}

// ----------------------------------------- TESTS -----------------------------------------

#[cfg(test)]
mod tests {
use super::AVX512FloatIgnoreNaN as AVX512;
use super::SIMDArgMinMaxFloatIgnoreNaN;
use crate::scalar::generic::scalar_argminmax;
use crate::simd::generic::SIMDArgMinMaxFloatIgnoreNaN;

use half::f16;

Expand Down
Loading

0 comments on commit 07a5e66

Please sign in to comment.