diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 457729aaa..2baadd8bc 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -2,7 +2,7 @@ name: Rust on: push: - branches: ["main", "vchord"] + branches: ["vchord"] paths: - ".cargo/**" - ".github/**" @@ -16,7 +16,7 @@ on: - "vectors.control" - "vendor/**" pull_request: - branches: ["main", "vchord"] + branches: ["vchord"] paths: - ".cargo/**" - ".github/**" @@ -60,6 +60,8 @@ jobs: sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' sudo apt-get purge -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' sudo apt-get update + curl --proto '=https' --tlsv1.2 -sSf https://apt.llvm.org/llvm.sh | sudo bash -s -- 18 + sudo update-alternatives --install /usr/bin/clang clang $(which clang-18) 255 sudo apt-get install -y build-essential crossbuild-essential-arm64 sudo apt-get install -y qemu-user-static touch ~/.cargo/config.toml @@ -81,33 +83,9 @@ jobs: - name: Clippy run: | cargo clippy --workspace --exclude pgvectors --exclude pyvectors --target $ARCH-unknown-linux-gnu - export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg14_${ARCH}_debian/pg_config/pg_config - export PGRX_TARGET_INFO_PATH_PG14=$(pwd)/vendor/pg14_${ARCH}_debian/pgrx_binding - cargo clippy --package pgvectors --features pg14 --no-deps --target $ARCH-unknown-linux-gnu - export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg15_${ARCH}_debian/pg_config/pg_config - export PGRX_TARGET_INFO_PATH_PG15=$(pwd)/vendor/pg15_${ARCH}_debian/pgrx_binding - cargo clippy --package pgvectors --features pg15 --no-deps --target $ARCH-unknown-linux-gnu - export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg16_${ARCH}_debian/pg_config/pg_config - export PGRX_TARGET_INFO_PATH_PG16=$(pwd)/vendor/pg16_${ARCH}_debian/pgrx_binding - cargo clippy --package pgvectors --features pg16 --no-deps --target $ARCH-unknown-linux-gnu - export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg17_${ARCH}_debian/pg_config/pg_config - export PGRX_TARGET_INFO_PATH_PG17=$(pwd)/vendor/pg17_${ARCH}_debian/pgrx_binding - cargo clippy --package pgvectors --features pg17 --no-deps --target $ARCH-unknown-linux-gnu - name: Build run: | cargo build --workspace --exclude pgvectors --exclude pyvectors --target $ARCH-unknown-linux-gnu - export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg14_${ARCH}_debian/pg_config/pg_config - export PGRX_TARGET_INFO_PATH_PG14=$(pwd)/vendor/pg14_${ARCH}_debian/pgrx_binding - cargo build --package pgvectors --lib --features pg14 --target $ARCH-unknown-linux-gnu - export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg15_${ARCH}_debian/pg_config/pg_config - export PGRX_TARGET_INFO_PATH_PG15=$(pwd)/vendor/pg15_${ARCH}_debian/pgrx_binding - cargo build --package pgvectors --lib --features pg15 --target $ARCH-unknown-linux-gnu - export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg16_${ARCH}_debian/pg_config/pg_config - export PGRX_TARGET_INFO_PATH_PG16=$(pwd)/vendor/pg16_${ARCH}_debian/pgrx_binding - cargo build --package pgvectors --lib --features pg16 --target $ARCH-unknown-linux-gnu - export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg17_${ARCH}_debian/pg_config/pg_config - export PGRX_TARGET_INFO_PATH_PG17=$(pwd)/vendor/pg17_${ARCH}_debian/pgrx_binding - cargo build --package pgvectors --lib --features pg17 --target $ARCH-unknown-linux-gnu - name: Test run: | cargo test --workspace --exclude pgvectors --exclude pyvectors --no-fail-fast --target $ARCH-unknown-linux-gnu diff --git a/Cargo.lock b/Cargo.lock index b9abb1257..9d8d94e8f 100644 --- 
a/Cargo.lock +++ b/Cargo.lock @@ -428,7 +428,7 @@ name = "base" version = "0.0.0" dependencies = [ "base_macros", - "detect", + "cc", "half 2.4.1", "libc", "rand", @@ -586,9 +586,12 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.8" +version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "504bdec147f2cc13c8b57ed9401fd8a147cc66b67ad5cb241394244f2c947549" +checksum = "27f657647bcff5394bf56c7317665bbf790a137a50eaaa5c6bfbb9e27a518f2d" +dependencies = [ + "shlex", +] [[package]] name = "cee-scape" @@ -824,22 +827,6 @@ dependencies = [ "parking_lot_core", ] -[[package]] -name = "detect" -version = "0.0.0" -dependencies = [ - "detect_macros", -] - -[[package]] -name = "detect_macros" -version = "0.0.0" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.72", -] - [[package]] name = "dirs-next" version = "2.0.0" diff --git a/crates/base/Cargo.toml b/crates/base/Cargo.toml index 2cf515d7f..f8ca62354 100644 --- a/crates/base/Cargo.toml +++ b/crates/base/Cargo.toml @@ -13,7 +13,9 @@ toml.workspace = true validator.workspace = true base_macros = { path = "../base_macros" } -detect = { path = "../detect" } + +[build-dependencies] +cc = "1.2.3" [lints] workspace = true diff --git a/crates/base/build.rs b/crates/base/build.rs new file mode 100644 index 000000000..c8ea529f1 --- /dev/null +++ b/crates/base/build.rs @@ -0,0 +1,8 @@ +fn main() { + println!("cargo::rerun-if-changed=cshim.c"); + cc::Build::new() + .compiler("clang") + .file("cshim.c") + .opt_level(3) + .compile("base_cshim"); +} diff --git a/crates/base/cshim.c b/crates/base/cshim.c new file mode 100644 index 000000000..d9d9e18b4 --- /dev/null +++ b/crates/base/cshim.c @@ -0,0 +1,219 @@ +#include +#include + +#ifdef __aarch64__ + +#include +#include + +__attribute__((target("v8.3a,fp16"))) float +fp16_reduce_sum_of_d2_v8_3a_fp16_unroll(__fp16 *__restrict a, + __fp16 *__restrict b, size_t n) { + float16x8_t d2_0 = vdupq_n_f16(0.0); + float16x8_t d2_1 = vdupq_n_f16(0.0); + float16x8_t d2_2 = vdupq_n_f16(0.0); + float16x8_t d2_3 = vdupq_n_f16(0.0); + while (n >= 32) { + float16x8_t x_0 = vld1q_f16(a + 0); + float16x8_t x_1 = vld1q_f16(a + 8); + float16x8_t x_2 = vld1q_f16(a + 16); + float16x8_t x_3 = vld1q_f16(a + 24); + float16x8_t y_0 = vld1q_f16(b + 0); + float16x8_t y_1 = vld1q_f16(b + 8); + float16x8_t y_2 = vld1q_f16(b + 16); + float16x8_t y_3 = vld1q_f16(b + 24); + a += 32; + b += 32; + n -= 32; + float16x8_t d_0 = vsubq_f16(x_0, y_0); + float16x8_t d_1 = vsubq_f16(x_1, y_1); + float16x8_t d_2 = vsubq_f16(x_2, y_2); + float16x8_t d_3 = vsubq_f16(x_3, y_3); + d2_0 = vfmaq_f16(d2_0, d_0, d_0); + d2_1 = vfmaq_f16(d2_1, d_1, d_1); + d2_2 = vfmaq_f16(d2_2, d_2, d_2); + d2_3 = vfmaq_f16(d2_3, d_3, d_3); + } + if (n > 0) { + __fp16 A[32] = {}; + __fp16 B[32] = {}; + for (size_t i = 0; i < n; i += 1) { + A[i] = a[i]; + B[i] = b[i]; + } + float16x8_t x_0 = vld1q_f16(A + 0); + float16x8_t x_1 = vld1q_f16(A + 8); + float16x8_t x_2 = vld1q_f16(A + 16); + float16x8_t x_3 = vld1q_f16(A + 24); + float16x8_t y_0 = vld1q_f16(B + 0); + float16x8_t y_1 = vld1q_f16(B + 8); + float16x8_t y_2 = vld1q_f16(B + 16); + float16x8_t y_3 = vld1q_f16(B + 24); + float16x8_t d_0 = vsubq_f16(x_0, y_0); + float16x8_t d_1 = vsubq_f16(x_1, y_1); + float16x8_t d_2 = vsubq_f16(x_2, y_2); + float16x8_t d_3 = vsubq_f16(x_3, y_3); + d2_0 = vfmaq_f16(d2_0, d_0, d_0); + d2_1 = vfmaq_f16(d2_1, d_1, d_1); + d2_2 = vfmaq_f16(d2_2, d_2, d_2); + d2_3 = vfmaq_f16(d2_3, d_3, d_3); + } + float16x8_t d2 = vaddq_f16(vaddq_f16(d2_0, 
d2_1), vaddq_f16(d2_2, d2_3)); + return vgetq_lane_f16(d2, 0) + vgetq_lane_f16(d2, 1) + vgetq_lane_f16(d2, 2) + + vgetq_lane_f16(d2, 3) + vgetq_lane_f16(d2, 4) + vgetq_lane_f16(d2, 5) + + vgetq_lane_f16(d2, 6) + vgetq_lane_f16(d2, 7); +} + +__attribute__((target("v8.3a,fp16"))) float +fp16_reduce_sum_of_xy_v8_3a_fp16_unroll(__fp16 *__restrict a, + __fp16 *__restrict b, size_t n) { + float16x8_t xy_0 = vdupq_n_f16(0.0); + float16x8_t xy_1 = vdupq_n_f16(0.0); + float16x8_t xy_2 = vdupq_n_f16(0.0); + float16x8_t xy_3 = vdupq_n_f16(0.0); + while (n >= 32) { + float16x8_t x_0 = vld1q_f16(a + 0); + float16x8_t x_1 = vld1q_f16(a + 8); + float16x8_t x_2 = vld1q_f16(a + 16); + float16x8_t x_3 = vld1q_f16(a + 24); + float16x8_t y_0 = vld1q_f16(b + 0); + float16x8_t y_1 = vld1q_f16(b + 8); + float16x8_t y_2 = vld1q_f16(b + 16); + float16x8_t y_3 = vld1q_f16(b + 24); + a += 32; + b += 32; + n -= 32; + xy_0 = vfmaq_f16(xy_0, x_0, y_0); + xy_1 = vfmaq_f16(xy_1, x_1, y_1); + xy_2 = vfmaq_f16(xy_2, x_2, y_2); + xy_3 = vfmaq_f16(xy_3, x_3, y_3); + } + if (n > 0) { + __fp16 A[32] = {}; + __fp16 B[32] = {}; + for (size_t i = 0; i < n; i += 1) { + A[i] = a[i]; + B[i] = b[i]; + } + float16x8_t x_0 = vld1q_f16(A + 0); + float16x8_t x_1 = vld1q_f16(A + 8); + float16x8_t x_2 = vld1q_f16(A + 16); + float16x8_t x_3 = vld1q_f16(A + 24); + float16x8_t y_0 = vld1q_f16(B + 0); + float16x8_t y_1 = vld1q_f16(B + 8); + float16x8_t y_2 = vld1q_f16(B + 16); + float16x8_t y_3 = vld1q_f16(B + 24); + xy_0 = vfmaq_f16(xy_0, x_0, y_0); + xy_1 = vfmaq_f16(xy_1, x_1, y_1); + xy_2 = vfmaq_f16(xy_2, x_2, y_2); + xy_3 = vfmaq_f16(xy_3, x_3, y_3); + } + float16x8_t xy = vaddq_f16(vaddq_f16(xy_0, xy_1), vaddq_f16(xy_2, xy_3)); + return vgetq_lane_f16(xy, 0) + vgetq_lane_f16(xy, 1) + vgetq_lane_f16(xy, 2) + + vgetq_lane_f16(xy, 3) + vgetq_lane_f16(xy, 4) + vgetq_lane_f16(xy, 5) + + vgetq_lane_f16(xy, 6) + vgetq_lane_f16(xy, 7); +} + +__attribute__((target("v8.3a,sve"))) float +fp16_reduce_sum_of_d2_v8_3a_sve(__fp16 *__restrict a, __fp16 *__restrict b, + size_t n) { + svfloat16_t d2 = svdup_f16(0.0); + for (size_t i = 0; i < n; i += svcnth()) { + svbool_t mask = svwhilelt_b16(i, n); + svfloat16_t x = svld1_f16(mask, a + i); + svfloat16_t y = svld1_f16(mask, b + i); + svfloat16_t d = svsub_f16_x(mask, x, y); + d2 = svmla_f16_x(mask, d2, d, d); + } + return svaddv_f16(svptrue_b16(), d2); +} + +__attribute__((target("v8.3a,sve"))) float +fp16_reduce_sum_of_xy_v8_3a_sve(__fp16 *__restrict a, __fp16 *__restrict b, + size_t n) { + svfloat16_t xy = svdup_f16(0.0); + for (size_t i = 0; i < n; i += svcnth()) { + svbool_t mask = svwhilelt_b16(i, n); + svfloat16_t x = svld1_f16(mask, a + i); + svfloat16_t y = svld1_f16(mask, b + i); + xy = svmla_f16_x(mask, xy, x, y); + } + return svaddv_f16(svptrue_b16(), xy); +} + +__attribute__((target("v8.3a,sve"))) float +fp32_reduce_sum_of_x_v8_3a_sve(float *__restrict this, size_t n) { + svfloat32_t sum = svdup_f32(0.0); + for (size_t i = 0; i < n; i += svcntw()) { + svbool_t mask = svwhilelt_b32(i, n); + svfloat32_t x = svld1_f32(mask, this + i); + sum = svadd_f32_x(mask, sum, x); + } + return svaddv_f32(svptrue_b32(), sum); +} + +__attribute__((target("v8.3a,sve"))) float +fp32_reduce_sum_of_abs_x_v8_3a_sve(float *__restrict this, size_t n) { + svfloat32_t sum = svdup_f32(0.0); + for (size_t i = 0; i < n; i += svcntw()) { + svbool_t mask = svwhilelt_b32(i, n); + svfloat32_t x = svld1_f32(mask, this + i); + sum = svadd_f32_x(mask, sum, svabs_f32_x(mask, x)); + } + return svaddv_f32(svptrue_b32(), sum); +} + 
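Note: the C kernels in this new cshim.c are compiled by the new crates/base/build.rs through the `cc` crate, and the per-function `__attribute__((target("v8.3a,sve")))` / fp16 annotations are presumably why the CI change above installs LLVM/Clang 18 and repoints /usr/bin/clang. A slightly more defensive variant of that build script is sketched below; the explicit clang probe is hypothetical and not part of this change, the `cc` invocation itself is copied from the real build.rs.

```rust
// Hypothetical variant of crates/base/build.rs (sketch only, not in this diff):
// same cc invocation as the real script, plus an early, readable failure when
// clang is missing instead of a cryptic compiler error from `cc`.
fn main() {
    println!("cargo::rerun-if-changed=cshim.c");
    // cshim.c relies on clang's per-function target attributes (v8.3a, sve, fp16),
    // which is presumably why the CI installs clang-18.
    let clang_present = std::process::Command::new("clang")
        .arg("--version")
        .output()
        .is_ok();
    assert!(
        clang_present,
        "clang not found; cshim.c is built with clang (see the rust.yml change above)"
    );
    cc::Build::new()
        .compiler("clang")
        .file("cshim.c")
        .opt_level(3)
        .compile("base_cshim");
}
```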
+__attribute__((target("v8.3a,sve"))) float +fp32_reduce_sum_of_x2_v8_3a_sve(float *__restrict this, size_t n) { + svfloat32_t sum = svdup_f32(0.0); + for (size_t i = 0; i < n; i += svcntw()) { + svbool_t mask = svwhilelt_b32(i, n); + svfloat32_t x = svld1_f32(mask, this + i); + sum = svmla_f32_x(mask, sum, x, x); + } + return svaddv_f32(svptrue_b32(), sum); +} + +__attribute__((target("v8.3a,sve"))) void +fp32_reduce_min_max_of_x_v8_3a_sve(float *__restrict this, size_t n, + float *out_min, float *out_max) { + svfloat32_t min = svdup_f32(1.0 / 0.0); + svfloat32_t max = svdup_f32(-1.0 / 0.0); + for (size_t i = 0; i < n; i += svcntw()) { + svbool_t mask = svwhilelt_b32(i, n); + svfloat32_t x = svld1_f32(mask, this + i); + min = svmin_f32_x(mask, min, x); + max = svmax_f32_x(mask, max, x); + } + *out_min = svminv_f32(svptrue_b32(), min); + *out_max = svmaxv_f32(svptrue_b32(), max); +} + +__attribute__((target("v8.3a,sve"))) float +fp32_reduce_sum_of_xy_v8_3a_sve(float *__restrict lhs, float *__restrict rhs, + size_t n) { + svfloat32_t sum = svdup_f32(0.0); + for (size_t i = 0; i < n; i += svcntw()) { + svbool_t mask = svwhilelt_b32(i, n); + svfloat32_t x = svld1_f32(mask, lhs + i); + svfloat32_t y = svld1_f32(mask, rhs + i); + sum = svmla_f32_x(mask, sum, x, y); + } + return svaddv_f32(svptrue_b32(), sum); +} + +__attribute__((target("v8.3a,sve"))) float +fp32_reduce_sum_of_d2_v8_3a_sve(float *__restrict lhs, float *__restrict rhs, + size_t n) { + svfloat32_t sum = svdup_f32(0.0); + for (size_t i = 0; i < n; i += svcntw()) { + svbool_t mask = svwhilelt_b32(i, n); + svfloat32_t x = svld1_f32(mask, lhs + i); + svfloat32_t y = svld1_f32(mask, rhs + i); + svfloat32_t d = svsub_f32_x(mask, x, y); + sum = svmla_f32_x(mask, sum, d, d); + } + return svaddv_f32(svptrue_b32(), sum); +} + +#endif diff --git a/crates/base/src/lib.rs b/crates/base/src/lib.rs index 2e78cdbe1..a04ed3a83 100644 --- a/crates/base/src/lib.rs +++ b/crates/base/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(target_feature_11)] #![feature(avx512_target_feature)] #![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))] #![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512_f16))] diff --git a/crates/base/src/simd/bit.rs b/crates/base/src/simd/bit.rs index ec4ddb730..d1809fa3e 100644 --- a/crates/base/src/simd/bit.rs +++ b/crates/base/src/simd/bit.rs @@ -8,8 +8,9 @@ mod sum_of_and { #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - unsafe fn sum_of_and_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { + #[crate::simd::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512vpopcntdq")] + fn sum_of_and_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -38,21 +39,22 @@ mod sum_of_and { #[cfg(all(target_arch = "x86_64", test))] #[test] fn sum_of_and_v4_avx512vpopcntdq_test() { - detect::init(); - if !detect::v4_avx512vpopcntdq::detect() { - println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); + if !crate::simd::is_cpu_detected!("v4") + || !crate::simd::is_feature_detected!("avx512vpopcntdq") + { + println!("test {} ... 
skipped (v4:avx512vpopcntdq)", module_path!()); return; } for _ in 0..300 { let lhs = (0..126).map(|_| rand::random::()).collect::>(); let rhs = (0..126).map(|_| rand::random::()).collect::>(); let specialized = unsafe { sum_of_and_v4_avx512vpopcntdq(&lhs, &rhs) }; - let fallback = unsafe { sum_of_and_fallback(&lhs, &rhs) }; + let fallback = sum_of_and_fallback(&lhs, &rhs); assert_eq!(specialized, fallback); } } - #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] + #[crate::simd::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] pub fn sum_of_and(lhs: &[u64], rhs: &[u64]) -> u32 { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -74,8 +76,9 @@ mod sum_of_or { #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - unsafe fn sum_of_or_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { + #[crate::simd::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512vpopcntdq")] + fn sum_of_or_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -104,21 +107,22 @@ mod sum_of_or { #[cfg(all(target_arch = "x86_64", test))] #[test] fn sum_of_or_v4_avx512vpopcntdq_test() { - detect::init(); - if !detect::v4_avx512vpopcntdq::detect() { - println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); + if !crate::simd::is_cpu_detected!("v4") + || !crate::simd::is_feature_detected!("avx512vpopcntdq") + { + println!("test {} ... skipped (v4:avx512vpopcntdq)", module_path!()); return; } for _ in 0..300 { let lhs = (0..126).map(|_| rand::random::()).collect::>(); let rhs = (0..126).map(|_| rand::random::()).collect::>(); let specialized = unsafe { sum_of_or_v4_avx512vpopcntdq(&lhs, &rhs) }; - let fallback = unsafe { sum_of_or_fallback(&lhs, &rhs) }; + let fallback = sum_of_or_fallback(&lhs, &rhs); assert_eq!(specialized, fallback); } } - #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] + #[crate::simd::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] pub fn sum_of_or(lhs: &[u64], rhs: &[u64]) -> u32 { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -140,8 +144,9 @@ mod sum_of_xor { #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - unsafe fn sum_of_xor_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { + #[crate::simd::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512vpopcntdq")] + fn sum_of_xor_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -170,21 +175,22 @@ mod sum_of_xor { #[cfg(all(target_arch = "x86_64", test))] #[test] fn sum_of_xor_v4_avx512vpopcntdq_test() { - detect::init(); - if !detect::v4_avx512vpopcntdq::detect() { - println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); + if !crate::simd::is_cpu_detected!("v4") + || !crate::simd::is_feature_detected!("avx512vpopcntdq") + { + println!("test {} ... 
skipped (v4:avx512vpopcntdq)", module_path!()); return; } for _ in 0..300 { let lhs = (0..126).map(|_| rand::random::()).collect::>(); let rhs = (0..126).map(|_| rand::random::()).collect::>(); let specialized = unsafe { sum_of_xor_v4_avx512vpopcntdq(&lhs, &rhs) }; - let fallback = unsafe { sum_of_xor_fallback(&lhs, &rhs) }; + let fallback = sum_of_xor_fallback(&lhs, &rhs); assert_eq!(specialized, fallback); } } - #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] + #[crate::simd::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] pub fn sum_of_xor(lhs: &[u64], rhs: &[u64]) -> u32 { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -206,8 +212,9 @@ mod sum_of_and_or { #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - unsafe fn sum_of_and_or_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> (u32, u32) { + #[crate::simd::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512vpopcntdq")] + fn sum_of_and_or_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> (u32, u32) { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -242,21 +249,22 @@ mod sum_of_and_or { #[cfg(all(target_arch = "x86_64", test))] #[test] fn sum_of_xor_v4_avx512vpopcntdq_test() { - detect::init(); - if !detect::v4_avx512vpopcntdq::detect() { - println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); + if !crate::simd::is_cpu_detected!("v4") + || !crate::simd::is_feature_detected!("avx512vpopcntdq") + { + println!("test {} ... skipped (v4:avx512vpopcntdq)", module_path!()); return; } for _ in 0..300 { let lhs = (0..126).map(|_| rand::random::()).collect::>(); let rhs = (0..126).map(|_| rand::random::()).collect::>(); let specialized = unsafe { sum_of_and_or_v4_avx512vpopcntdq(&lhs, &rhs) }; - let fallback = unsafe { sum_of_and_or_fallback(&lhs, &rhs) }; + let fallback = sum_of_and_or_fallback(&lhs, &rhs); assert_eq!(specialized, fallback); } } - #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] + #[crate::simd::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] pub fn sum_of_and_or(lhs: &[u64], rhs: &[u64]) -> (u32, u32) { assert_eq!(lhs.len(), rhs.len()); let n = lhs.len(); @@ -280,8 +288,9 @@ mod sum_of_x { #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - unsafe fn sum_of_x_v4_avx512vpopcntdq(this: &[u64]) -> u32 { + #[crate::simd::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512vpopcntdq")] + fn sum_of_x_v4_avx512vpopcntdq(this: &[u64]) -> u32 { unsafe { use std::arch::x86_64::*; let mut and = _mm512_setzero_si512(); @@ -305,20 +314,21 @@ mod sum_of_x { #[cfg(all(target_arch = "x86_64", test))] #[test] fn sum_of_x_v4_avx512vpopcntdq_test() { - detect::init(); - if !detect::v4_avx512vpopcntdq::detect() { - println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); + if !crate::simd::is_cpu_detected!("v4") + || !crate::simd::is_feature_detected!("avx512vpopcntdq") + { + println!("test {} ... 
skipped (v4:avx512vpopcntdq)", module_path!()); return; } for _ in 0..300 { let this = (0..126).map(|_| rand::random::()).collect::>(); let specialized = unsafe { sum_of_x_v4_avx512vpopcntdq(&this) }; - let fallback = unsafe { sum_of_x_fallback(&this) }; + let fallback = sum_of_x_fallback(&this); assert_eq!(specialized, fallback); } } - #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] + #[crate::simd::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] pub fn sum_of_x(this: &[u64]) -> u32 { let n = this.len(); let mut and = 0; @@ -329,50 +339,71 @@ mod sum_of_x { } } -#[detect::multiversion(v4, v3, v2, neon, fallback)] +#[inline(always)] pub fn vector_and(lhs: &[u64], rhs: &[u64]) -> Vec { - assert_eq!(lhs.len(), rhs.len()); - let n = lhs.len(); - let mut r = Vec::::with_capacity(n); - for i in 0..n { + vector_and::vector_and(lhs, rhs) +} + +mod vector_and { + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_and(lhs: &[u64], rhs: &[u64]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] & rhs[i]); + } + } unsafe { - r.as_mut_ptr().add(i).write(lhs[i] & rhs[i]); + r.set_len(n); } + r } - unsafe { - r.set_len(n); - } - r } -#[detect::multiversion(v4, v3, v2, neon, fallback)] +#[inline(always)] pub fn vector_or(lhs: &[u64], rhs: &[u64]) -> Vec { - assert_eq!(lhs.len(), rhs.len()); - let n = lhs.len(); - let mut r = Vec::::with_capacity(n); - for i in 0..n { + vector_or::vector_or(lhs, rhs) +} + +mod vector_or { + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_or(lhs: &[u64], rhs: &[u64]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] | rhs[i]); + } + } unsafe { - r.as_mut_ptr().add(i).write(lhs[i] | rhs[i]); + r.set_len(n); } + r } - unsafe { - r.set_len(n); - } - r } -#[detect::multiversion(v4, v3, v2, neon, fallback)] +#[inline(always)] pub fn vector_xor(lhs: &[u64], rhs: &[u64]) -> Vec { - assert_eq!(lhs.len(), rhs.len()); - let n = lhs.len(); - let mut r = Vec::::with_capacity(n); - for i in 0..n { + vector_xor::vector_xor(lhs, rhs) +} + +mod vector_xor { + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_xor(lhs: &[u64], rhs: &[u64]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] ^ rhs[i]); + } + } unsafe { - r.as_mut_ptr().add(i).write(lhs[i] ^ rhs[i]); + r.set_len(n); } + r } - unsafe { - r.set_len(n); - } - r } diff --git a/crates/base/src/simd/emulate.rs b/crates/base/src/simd/emulate.rs index 60eb2c423..ed74d6645 100644 --- a/crates/base/src/simd/emulate.rs +++ b/crates/base/src/simd/emulate.rs @@ -5,8 +5,8 @@ // Instructions. arXiv preprint arXiv:2112.06342. 
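Note: throughout this patch, functions previously annotated with `#[detect::multiversion(...)]` move into per-function modules behind `#[inline(always)]` forwarders, and the new `#[crate::simd::multiversion(...)]` attribute takes a list of CPU targets. As a rough mental model only (this is not the macro's actual expansion, and the real specializations use AVX-512 features gated behind nightly), such a wrapper boils down to "detect once, cache a function pointer, otherwise run the scalar body". The sketch below uses plain `popcnt` so it compiles on stable Rust:

```rust
use std::sync::OnceLock;

fn popcount_fallback(x: &[u64]) -> u32 {
    x.iter().map(|v| v.count_ones()).sum()
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "popcnt")]
unsafe fn popcount_specialized(x: &[u64]) -> u32 {
    // With the feature enabled, count_ones() lowers to the POPCNT instruction.
    x.iter().map(|v| v.count_ones()).sum()
}

#[cfg(target_arch = "x86_64")]
fn popcount_specialized_safe(x: &[u64]) -> u32 {
    // Safety: only selected by `choose` after the runtime check below.
    unsafe { popcount_specialized(x) }
}

fn choose() -> fn(&[u64]) -> u32 {
    #[cfg(target_arch = "x86_64")]
    {
        if std::arch::is_x86_feature_detected!("popcnt") {
            return popcount_specialized_safe;
        }
    }
    popcount_fallback
}

pub fn popcount(x: &[u64]) -> u32 {
    // Detect once per process, then keep reusing the chosen specialization.
    static CHOSEN: OnceLock<fn(&[u64]) -> u32> = OnceLock::new();
    CHOSEN.get_or_init(choose)(x)
}
```

The `@"..."` entries in the attribute appear to reference hand-written specializations defined next to the wrapper (such as sum_of_and_v4_avx512vpopcntdq above), while the plain `"..."` entries presumably ask the macro to generate a target-attributed copy of the generic body.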
#[inline] #[cfg(target_arch = "x86_64")] -#[detect::target_cpu(enable = "v4")] -pub unsafe fn emulate_mm512_2intersect_epi32( +#[crate::simd::target_cpu(enable = "v4")] +pub fn emulate_mm512_2intersect_epi32( a: std::arch::x86_64::__m512i, b: std::arch::x86_64::__m512i, ) -> (std::arch::x86_64::__mmask16, std::arch::x86_64::__mmask16) { @@ -62,8 +62,8 @@ pub unsafe fn emulate_mm512_2intersect_epi32( #[inline] #[cfg(target_arch = "x86_64")] -#[detect::target_cpu(enable = "v3")] -pub unsafe fn emulate_mm256_reduce_add_ps(mut x: std::arch::x86_64::__m256) -> f32 { +#[crate::simd::target_cpu(enable = "v3")] +pub fn emulate_mm256_reduce_add_ps(mut x: std::arch::x86_64::__m256) -> f32 { unsafe { use std::arch::x86_64::*; x = _mm256_add_ps(x, _mm256_permute2f128_ps(x, x, 1)); @@ -75,8 +75,20 @@ pub unsafe fn emulate_mm256_reduce_add_ps(mut x: std::arch::x86_64::__m256) -> f #[inline] #[cfg(target_arch = "x86_64")] -#[detect::target_cpu(enable = "v4")] -pub unsafe fn emulate_mm512_reduce_add_epi16(x: std::arch::x86_64::__m512i) -> i16 { +#[crate::simd::target_cpu(enable = "v2")] +pub fn emulate_mm_reduce_add_ps(mut x: std::arch::x86_64::__m128) -> f32 { + unsafe { + use std::arch::x86_64::*; + x = _mm_hadd_ps(x, x); + x = _mm_hadd_ps(x, x); + _mm_cvtss_f32(x) + } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::simd::target_cpu(enable = "v4")] +pub fn emulate_mm512_reduce_add_epi16(x: std::arch::x86_64::__m512i) -> i16 { unsafe { use std::arch::x86_64::*; _mm256_reduce_add_epi16(_mm512_castsi512_si256(x)) @@ -86,8 +98,8 @@ pub unsafe fn emulate_mm512_reduce_add_epi16(x: std::arch::x86_64::__m512i) -> i #[inline] #[cfg(target_arch = "x86_64")] -#[detect::target_cpu(enable = "v3")] -pub unsafe fn emulate_mm256_reduce_add_epi16(mut x: std::arch::x86_64::__m256i) -> i16 { +#[crate::simd::target_cpu(enable = "v3")] +pub fn emulate_mm256_reduce_add_epi16(mut x: std::arch::x86_64::__m256i) -> i16 { unsafe { use std::arch::x86_64::*; x = _mm256_add_epi16(x, _mm256_permute2f128_si256(x, x, 1)); @@ -100,8 +112,21 @@ pub unsafe fn emulate_mm256_reduce_add_epi16(mut x: std::arch::x86_64::__m256i) #[inline] #[cfg(target_arch = "x86_64")] -#[detect::target_cpu(enable = "v3")] -pub unsafe fn emulate_mm256_reduce_add_epi32(mut x: std::arch::x86_64::__m256i) -> i32 { +#[crate::simd::target_cpu(enable = "v2")] +pub fn emulate_mm_reduce_add_epi16(mut x: std::arch::x86_64::__m128i) -> i16 { + unsafe { + use std::arch::x86_64::*; + x = _mm_hadd_epi16(x, x); + x = _mm_hadd_epi16(x, x); + let x = _mm_cvtsi128_si32(x); + (x as i16) + ((x >> 16) as i16) + } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::simd::target_cpu(enable = "v3")] +pub fn emulate_mm256_reduce_add_epi32(mut x: std::arch::x86_64::__m256i) -> i32 { unsafe { use std::arch::x86_64::*; x = _mm256_add_epi32(x, _mm256_permute2f128_si256(x, x, 1)); @@ -113,8 +138,20 @@ pub unsafe fn emulate_mm256_reduce_add_epi32(mut x: std::arch::x86_64::__m256i) #[inline] #[cfg(target_arch = "x86_64")] -#[detect::target_cpu(enable = "v3")] -pub unsafe fn emulate_mm256_reduce_min_ps(x: std::arch::x86_64::__m256) -> f32 { +#[crate::simd::target_cpu(enable = "v2")] +pub fn emulate_mm_reduce_add_epi32(mut x: std::arch::x86_64::__m128i) -> i32 { + unsafe { + use std::arch::x86_64::*; + x = _mm_hadd_epi32(x, x); + x = _mm_hadd_epi32(x, x); + _mm_cvtsi128_si32(x) + } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::simd::target_cpu(enable = "v3")] +pub fn emulate_mm256_reduce_min_ps(x: std::arch::x86_64::__m256) -> f32 { use crate::aligned::Aligned16; 
unsafe { use std::arch::x86_64::*; @@ -129,8 +166,22 @@ pub unsafe fn emulate_mm256_reduce_min_ps(x: std::arch::x86_64::__m256) -> f32 { #[inline] #[cfg(target_arch = "x86_64")] -#[detect::target_cpu(enable = "v3")] -pub unsafe fn emulate_mm256_reduce_max_ps(x: std::arch::x86_64::__m256) -> f32 { +#[crate::simd::target_cpu(enable = "v2")] +pub fn emulate_mm_reduce_min_ps(x: std::arch::x86_64::__m128) -> f32 { + use crate::aligned::Aligned16; + unsafe { + use std::arch::x86_64::*; + let min = x; + let mut x = Aligned16([0.0f32; 4]); + _mm_store_ps(x.0.as_mut_ptr(), min); + f32::min(f32::min(x.0[0], x.0[1]), f32::min(x.0[2], x.0[3])) + } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::simd::target_cpu(enable = "v3")] +pub fn emulate_mm256_reduce_max_ps(x: std::arch::x86_64::__m256) -> f32 { use crate::aligned::Aligned16; unsafe { use std::arch::x86_64::*; @@ -142,3 +193,17 @@ pub unsafe fn emulate_mm256_reduce_max_ps(x: std::arch::x86_64::__m256) -> f32 { f32::max(f32::max(x.0[0], x.0[1]), f32::max(x.0[2], x.0[3])) } } + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::simd::target_cpu(enable = "v2")] +pub fn emulate_mm_reduce_max_ps(x: std::arch::x86_64::__m128) -> f32 { + use crate::aligned::Aligned16; + unsafe { + use std::arch::x86_64::*; + let max = x; + let mut x = Aligned16([0.0f32; 4]); + _mm_store_ps(x.0.as_mut_ptr(), max); + f32::max(f32::max(x.0[0], x.0[1]), f32::max(x.0[2], x.0[3])) + } +} diff --git a/crates/base/src/simd/f16.rs b/crates/base/src/simd/f16.rs index cf7fb04c6..3cecf3cb8 100644 --- a/crates/base/src/simd/f16.rs +++ b/crates/base/src/simd/f16.rs @@ -57,61 +57,29 @@ impl ScalarLike for f16 { f16::to_f32(self) } - // FIXME: add manually-implemented SIMD version - #[detect::multiversion(v4, v3, v2, neon, fallback)] + #[inline(always)] fn reduce_or_of_is_zero(this: &[f16]) -> bool { - for &x in this { - if x == f16::ZERO { - return true; - } - } - false + reduce_or_of_is_zero::reduce_or_of_is_zero(this) } - // FIXME: add manually-implemented SIMD version - #[detect::multiversion(v4, v3, v2, neon, fallback)] + #[inline(always)] fn reduce_sum_of_x(this: &[f16]) -> f32 { - let n = this.len(); - let mut x = 0.0f32; - for i in 0..n { - x += this[i].to_f32(); - } - x + reduce_sum_of_x::reduce_sum_of_x(this) } - // FIXME: add manually-implemented SIMD version - #[detect::multiversion(v4, v3, v2, neon, fallback)] + #[inline(always)] fn reduce_sum_of_abs_x(this: &[f16]) -> f32 { - let n = this.len(); - let mut x = 0.0f32; - for i in 0..n { - x += this[i].to_f32().abs(); - } - x + reduce_sum_of_abs_x::reduce_sum_of_abs_x(this) } - // FIXME: add manually-implemented SIMD version - #[detect::multiversion(v4, v3, v2, neon, fallback)] + #[inline(always)] fn reduce_sum_of_x2(this: &[f16]) -> f32 { - let n = this.len(); - let mut x2 = 0.0f32; - for i in 0..n { - x2 += this[i].to_f32() * this[i].to_f32(); - } - x2 + reduce_sum_of_x2::reduce_sum_of_x2(this) } - // FIXME: add manually-implemented SIMD version - #[detect::multiversion(v4, v3, v2, neon, fallback)] + #[inline(always)] fn reduce_min_max_of_x(this: &[f16]) -> (f32, f32) { - let mut min = f32::INFINITY; - let mut max = f32::NEG_INFINITY; - let n = this.len(); - for i in 0..n { - min = min.min(this[i].to_f32()); - max = max.max(this[i].to_f32()); - } - (min, max) + reduce_min_max_of_x::reduce_min_max_of_x(this) } #[inline(always)] @@ -134,151 +102,152 @@ impl ScalarLike for f16 { reduce_sum_of_sparse_d2::reduce_sum_of_sparse_d2(lidx, lval, ridx, rval) } - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn 
vector_add(lhs: &[f16], rhs: &[f16]) -> Vec { - assert_eq!(lhs.len(), rhs.len()); - let n = lhs.len(); - let mut r = Vec::::with_capacity(n); - for i in 0..n { - unsafe { - r.as_mut_ptr().add(i).write(lhs[i] + rhs[i]); - } - } - unsafe { - r.set_len(n); - } - r + #[inline(always)] + fn vector_from_f32(this: &[f32]) -> Vec { + vector_from_f32::vector_from_f32(this) } - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn vector_add_inplace(lhs: &mut [f16], rhs: &[f16]) { - assert_eq!(lhs.len(), rhs.len()); - let n = lhs.len(); - for i in 0..n { - lhs[i] += rhs[i]; - } + #[inline(always)] + fn vector_to_f32(this: &[Self]) -> Vec { + vector_to_f32::vector_to_f32(this) } - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn vector_sub(lhs: &[f16], rhs: &[f16]) -> Vec { - assert_eq!(lhs.len(), rhs.len()); - let n = lhs.len(); - let mut r = Vec::::with_capacity(n); - for i in 0..n { - unsafe { - r.as_mut_ptr().add(i).write(lhs[i] - rhs[i]); - } - } - unsafe { - r.set_len(n); - } - r + #[inline(always)] + fn vector_add(lhs: &[Self], rhs: &[Self]) -> Vec { + vector_add::vector_add(lhs, rhs) } - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn vector_mul(lhs: &[f16], rhs: &[f16]) -> Vec { - assert_eq!(lhs.len(), rhs.len()); - let n = lhs.len(); - let mut r = Vec::::with_capacity(n); - for i in 0..n { - unsafe { - r.as_mut_ptr().add(i).write(lhs[i] * rhs[i]); - } - } - unsafe { - r.set_len(n); - } - r + #[inline(always)] + fn vector_add_inplace(lhs: &mut [Self], rhs: &[Self]) { + vector_add_inplace::vector_add_inplace(lhs, rhs) } - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn vector_mul_scalar(lhs: &[f16], rhs: f32) -> Vec { - let rhs = f16::from_f32(rhs); - let n = lhs.len(); - let mut r = Vec::::with_capacity(n); - for i in 0..n { - unsafe { - r.as_mut_ptr().add(i).write(lhs[i] * rhs); - } - } - unsafe { - r.set_len(n); - } - r + #[inline(always)] + fn vector_sub(lhs: &[Self], rhs: &[Self]) -> Vec { + vector_sub::vector_sub(lhs, rhs) } - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn vector_mul_scalar_inplace(lhs: &mut [f16], rhs: f32) { - let rhs = f16::from_f32(rhs); - let n = lhs.len(); - for i in 0..n { - lhs[i] *= rhs; + #[inline(always)] + fn vector_mul(lhs: &[Self], rhs: &[Self]) -> Vec { + vector_mul::vector_mul(lhs, rhs) + } + + #[inline(always)] + fn vector_mul_scalar(lhs: &[Self], rhs: f32) -> Vec { + vector_mul_scalar::vector_mul_scalar(lhs, rhs) + } + + #[inline(always)] + fn vector_mul_scalar_inplace(lhs: &mut [Self], rhs: f32) { + vector_mul_scalar_inplace::vector_mul_scalar_inplace(lhs, rhs) + } + + #[inline(always)] + fn vector_abs_inplace(this: &mut [Self]) { + vector_abs_inplace::vector_abs_inplace(this) + } + + #[inline(always)] + fn vector_to_f32_borrowed(this: &[Self]) -> impl AsRef<[f32]> { + Self::vector_to_f32(this) + } + + #[inline(always)] + fn kmeans_helper(this: &mut [Self], x: f32, y: f32) { + kmeans_helper::kmeans_helper(this, x, y) + } +} + +mod reduce_or_of_is_zero { + // FIXME: add manually-implemented SIMD version + + use half::f16; + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_or_of_is_zero(this: &[f16]) -> bool { + for &x in this { + if x == f16::ZERO { + return true; + } } + false } +} + +mod reduce_sum_of_x { + // FIXME: add manually-implemented SIMD version + + use half::f16; - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn vector_abs_inplace(this: &mut [f16]) { + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_sum_of_x(this: 
&[f16]) -> f32 { let n = this.len(); + let mut x = 0.0f32; for i in 0..n { - this[i] = f16::from_f32(this[i].to_f32().abs()); + x += this[i].to_f32(); } + x } +} - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn vector_from_f32(this: &[f32]) -> Vec { +mod reduce_sum_of_abs_x { + // FIXME: add manually-implemented SIMD version + + use half::f16; + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_sum_of_abs_x(this: &[f16]) -> f32 { let n = this.len(); - let mut r = Vec::::with_capacity(n); + let mut x = 0.0f32; for i in 0..n { - unsafe { - r.as_mut_ptr().add(i).write(f16::from_f32(this[i])); - } - } - unsafe { - r.set_len(n); + x += this[i].to_f32().abs(); } - r + x } +} - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn vector_to_f32(this: &[f16]) -> Vec { +mod reduce_sum_of_x2 { + // FIXME: add manually-implemented SIMD version + + use half::f16; + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_sum_of_x2(this: &[f16]) -> f32 { let n = this.len(); - let mut r = Vec::::with_capacity(n); + let mut x2 = 0.0f32; for i in 0..n { - unsafe { - r.as_mut_ptr().add(i).write(this[i].to_f32()); - } - } - unsafe { - r.set_len(n); + x2 += this[i].to_f32() * this[i].to_f32(); } - r + x2 } +} - fn vector_to_f32_borrowed(this: &[Self]) -> impl AsRef<[f32]> { - Self::vector_to_f32(this) - } +mod reduce_min_max_of_x { + // FIXME: add manually-implemented SIMD version - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn kmeans_helper(this: &mut [f16], x: f32, y: f32) { - let x = f16::from_f32(x); - let y = f16::from_f32(y); + use half::f16; + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_min_max_of_x(this: &[f16]) -> (f32, f32) { + let mut min = f32::INFINITY; + let mut max = f32::NEG_INFINITY; let n = this.len(); for i in 0..n { - if i % 2 == 0 { - this[i] *= x; - } else { - this[i] *= y; - } + min = min.min(this[i].to_f32()); + max = max.max(this[i].to_f32()); } + (min, max) } } mod reduce_sum_of_xy { use half::f16; + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4_avx512fp16")] - pub unsafe fn reduce_sum_of_xy_v4_avx512fp16(lhs: &[f16], rhs: &[f16]) -> f32 { + #[crate::simd::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512fp16")] + pub fn reduce_sum_of_xy_v4_avx512fp16(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -309,9 +278,9 @@ mod reduce_sum_of_xy { fn reduce_sum_of_xy_v4_avx512fp16_test() { use rand::Rng; const EPSILON: f32 = 2.0; - detect::init(); - if !detect::v4_avx512fp16::detect() { - println!("test {} ... skipped (v4_avx512fp16)", module_path!()); + if !crate::simd::is_cpu_detected!("v4") || !crate::simd::is_feature_detected!("avx512fp16") + { + println!("test {} ... skipped (v4:avx512fp16)", module_path!()); return; } let mut rng = rand::thread_rng(); @@ -327,7 +296,7 @@ mod reduce_sum_of_xy { let lhs = &lhs[..z]; let rhs = &rhs[..z]; let specialized = unsafe { reduce_sum_of_xy_v4_avx512fp16(lhs, rhs) }; - let fallback = unsafe { reduce_sum_of_xy_fallback(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
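Note on the loose tolerances (EPSILON = 2.0 here, 6.0 for the squared distance later): the NEON fp16 kernels in cshim.c, and presumably the avx512fp16 path as well, accumulate partial sums in half precision, while the scalar fallback accumulates in f32, so exact agreement cannot be expected for ~4000-element inputs in [-1, 1]. A tiny self-contained illustration of the gap, assuming only the `half` crate already present in the workspace:

```rust
use half::f16;

// Accumulate in f32, like reduce_sum_of_xy_fallback.
fn dot_f32_accum(lhs: &[f16], rhs: &[f16]) -> f32 {
    lhs.iter().zip(rhs).map(|(x, y)| x.to_f32() * y.to_f32()).sum()
}

// Accumulate in f16, rounding after every step -- similar in spirit to the
// half-precision FMA lanes used by the specialized kernels.
fn dot_f16_accum(lhs: &[f16], rhs: &[f16]) -> f32 {
    let mut acc = f16::ZERO;
    for (x, y) in lhs.iter().zip(rhs) {
        acc = acc + *x * *y;
    }
    acc.to_f32()
}

fn main() {
    let lhs: Vec<f16> = (0..4000).map(|i| f16::from_f32(((i % 7) as f32 - 3.0) / 3.0)).collect();
    let rhs: Vec<f16> = (0..4000).map(|i| f16::from_f32(((i % 5) as f32 - 2.0) / 2.0)).collect();
    let (a, b) = (dot_f32_accum(&lhs, &rhs), dot_f16_accum(&lhs, &rhs));
    println!("f32 accumulation = {a}, f16 accumulation = {b}, difference = {}", (a - b).abs());
}
```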
@@ -336,9 +305,10 @@ mod reduce_sum_of_xy { } } + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4")] - pub unsafe fn reduce_sum_of_xy_v4(lhs: &[f16], rhs: &[f16]) -> f32 { + #[crate::simd::target_cpu(enable = "v4")] + pub fn reduce_sum_of_xy_v4(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -369,8 +339,7 @@ mod reduce_sum_of_xy { fn reduce_sum_of_xy_v4_test() { use rand::Rng; const EPSILON: f32 = 2.0; - detect::init(); - if !detect::v4::detect() { + if !crate::simd::is_cpu_detected!("v4") { println!("test {} ... skipped (v4)", module_path!()); return; } @@ -384,7 +353,7 @@ mod reduce_sum_of_xy { .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) .collect::>(); let specialized = unsafe { reduce_sum_of_xy_v4(&lhs, &rhs) }; - let fallback = unsafe { reduce_sum_of_xy_fallback(&lhs, &rhs) }; + let fallback = reduce_sum_of_xy_fallback(&lhs, &rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -392,9 +361,10 @@ mod reduce_sum_of_xy { } } + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v3")] - pub unsafe fn reduce_sum_of_xy_v3(lhs: &[f16], rhs: &[f16]) -> f32 { + #[crate::simd::target_cpu(enable = "v3")] + pub fn reduce_sum_of_xy_v3(lhs: &[f16], rhs: &[f16]) -> f32 { use crate::simd::emulate::emulate_mm256_reduce_add_ps; assert!(lhs.len() == rhs.len()); unsafe { @@ -426,12 +396,178 @@ mod reduce_sum_of_xy { #[cfg(all(target_arch = "x86_64", test))] #[test] - fn reduce_sum_of_xy_v3_test() { + fn reduce_sum_of_xy_v3_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::simd::is_cpu_detected!("v3") { + println!("test {} ... skipped (v3)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v3(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::simd::target_cpu(enable = "v2")] + #[target_feature(enable = "f16c")] + #[target_feature(enable = "fma")] + pub fn reduce_sum_of_xy_v2_f16c_fma(lhs: &[f16], rhs: &[f16]) -> f32 { + use crate::simd::emulate::emulate_mm_reduce_add_ps; + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut xy = _mm_setzero_ps(); + while n >= 4 { + let x = _mm_cvtph_ps(_mm_loadu_si128(a.cast())); + let y = _mm_cvtph_ps(_mm_loadu_si128(b.cast())); + a = a.add(4); + b = b.add(4); + n -= 4; + xy = _mm_fmadd_ps(x, y, xy); + } + let mut xy = emulate_mm_reduce_add_ps(xy); + while n > 0 { + let x = a.read().to_f32(); + let y = b.read().to_f32(); + a = a.add(1); + b = b.add(1); + n -= 1; + xy += x * y; + } + xy + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_xy_v2_f16c_fma_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::simd::is_cpu_detected!("v2") + || !crate::simd::is_feature_detected!("f16c") + || !crate::simd::is_feature_detected!("fma") + { + println!("test {} ... 
skipped (v2:f16c:fma)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v2_f16c_fma(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "fp16")] + pub fn reduce_sum_of_xy_v8_3a_fp16(lhs: &[f16], rhs: &[f16]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + extern "C" { + fn fp16_reduce_sum_of_xy_v8_3a_fp16_unroll( + a: *const (), + b: *const (), + n: usize, + ) -> f32; + } + fp16_reduce_sum_of_xy_v8_3a_fp16_unroll( + lhs.as_ptr().cast(), + rhs.as_ptr().cast(), + lhs.len(), + ) + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_sum_of_xy_v8_3a_fp16_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::simd::is_cpu_detected!("v8.3a") || !crate::simd::is_feature_detected!("fp16") { + println!("test {} ... skipped (v8.3a:fp16)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v8_3a_fp16(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + pub fn reduce_sum_of_xy_v8_3a_sve(lhs: &[f16], rhs: &[f16]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + extern "C" { + fn fp16_reduce_sum_of_xy_v8_3a_sve(a: *const (), b: *const (), n: usize) -> f32; + } + fp16_reduce_sum_of_xy_v8_3a_sve(lhs.as_ptr().cast(), rhs.as_ptr().cast(), lhs.len()) + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_sum_of_xy_v8_3a_sve_test() { use rand::Rng; const EPSILON: f32 = 2.0; - detect::init(); - if !detect::v3::detect() { - println!("test {} ... skipped (v3)", module_path!()); + if !crate::simd::is_cpu_detected!("v8.3a") || !crate::simd::is_feature_detected!("sve") { + println!("test {} ... skipped (v8.3a:sve)", module_path!()); return; } let mut rng = rand::thread_rng(); @@ -446,8 +582,8 @@ mod reduce_sum_of_xy { for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_xy_v3(lhs, rhs) }; - let fallback = unsafe { reduce_sum_of_xy_fallback(lhs, rhs) }; + let specialized = unsafe { reduce_sum_of_xy_v8_3a_sve(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
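Note: for the aarch64 entries, the dispatch list later in this module (`@"v8.3a:sve", @"v8.3a:fp16"`) puts SVE ahead of the NEON fp16 unroll. The crate's own is_cpu_detected!/is_feature_detected! macros are not part of this excerpt, and the "v8.3a" baseline check is omitted here; a rough equivalent of the selection order using the standard library's detection macro would be (the strings name the C symbols and the Rust fallback added in this diff):

```rust
#[cfg(target_arch = "aarch64")]
fn choose_f16_dot_kernel() -> &'static str {
    // Prefer SVE when present, then the NEON fp16 unroll, then scalar code.
    if std::arch::is_aarch64_feature_detected!("sve") {
        "fp16_reduce_sum_of_xy_v8_3a_sve"
    } else if std::arch::is_aarch64_feature_detected!("fp16") {
        "fp16_reduce_sum_of_xy_v8_3a_fp16_unroll"
    } else {
        "reduce_sum_of_xy_fallback"
    }
}
```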
@@ -456,7 +592,7 @@ mod reduce_sum_of_xy { } } - #[detect::multiversion(v4_avx512fp16 = import, v4 = import, v3 = import, v2, neon, fallback = export)] + #[crate::simd::multiversion(@"v4:avx512fp16", @"v4", @"v3", @"v2:f16c:fma", @"v8.3a:sve", @"v8.3a:fp16")] pub fn reduce_sum_of_xy(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -471,9 +607,11 @@ mod reduce_sum_of_xy { mod reduce_sum_of_d2 { use half::f16; + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4_avx512fp16")] - pub unsafe fn reduce_sum_of_d2_v4_avx512fp16(lhs: &[f16], rhs: &[f16]) -> f32 { + #[crate::simd::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512fp16")] + pub fn reduce_sum_of_d2_v4_avx512fp16(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -506,9 +644,9 @@ mod reduce_sum_of_d2 { fn reduce_sum_of_d2_v4_avx512fp16_test() { use rand::Rng; const EPSILON: f32 = 6.0; - detect::init(); - if !detect::v4_avx512fp16::detect() { - println!("test {} ... skipped (v4_avx512fp16)", module_path!()); + if !crate::simd::is_cpu_detected!("v4") || !crate::simd::is_feature_detected!("avx512fp16") + { + println!("test {} ... skipped (v4:avx512fp16)", module_path!()); return; } let mut rng = rand::thread_rng(); @@ -524,7 +662,7 @@ mod reduce_sum_of_d2 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; let specialized = unsafe { reduce_sum_of_d2_v4_avx512fp16(lhs, rhs) }; - let fallback = unsafe { reduce_sum_of_d2_fallback(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -533,9 +671,10 @@ mod reduce_sum_of_d2 { } } + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4")] - pub unsafe fn reduce_sum_of_d2_v4(lhs: &[f16], rhs: &[f16]) -> f32 { + #[crate::simd::target_cpu(enable = "v4")] + pub fn reduce_sum_of_d2_v4(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; @@ -568,8 +707,7 @@ mod reduce_sum_of_d2 { fn reduce_sum_of_d2_v4_test() { use rand::Rng; const EPSILON: f32 = 2.0; - detect::init(); - if !detect::v4::detect() { + if !crate::simd::is_cpu_detected!("v4") { println!("test {} ... skipped (v4)", module_path!()); return; } @@ -586,7 +724,7 @@ mod reduce_sum_of_d2 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; let specialized = unsafe { reduce_sum_of_d2_v4(lhs, rhs) }; - let fallback = unsafe { reduce_sum_of_d2_fallback(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -595,9 +733,10 @@ mod reduce_sum_of_d2 { } } + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v3")] - pub unsafe fn reduce_sum_of_d2_v3(lhs: &[f16], rhs: &[f16]) -> f32 { + #[crate::simd::target_cpu(enable = "v3")] + pub fn reduce_sum_of_d2_v3(lhs: &[f16], rhs: &[f16]) -> f32 { use crate::simd::emulate::emulate_mm256_reduce_add_ps; assert!(lhs.len() == rhs.len()); unsafe { @@ -634,8 +773,7 @@ mod reduce_sum_of_d2 { fn reduce_sum_of_d2_v3_test() { use rand::Rng; const EPSILON: f32 = 2.0; - detect::init(); - if !detect::v3::detect() { + if !crate::simd::is_cpu_detected!("v3") { println!("test {} ... 
skipped (v3)", module_path!()); return; } @@ -652,7 +790,176 @@ mod reduce_sum_of_d2 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; let specialized = unsafe { reduce_sum_of_d2_v3(lhs, rhs) }; - let fallback = unsafe { reduce_sum_of_d2_fallback(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::simd::target_cpu(enable = "v2")] + #[target_feature(enable = "f16c")] + #[target_feature(enable = "fma")] + pub fn reduce_sum_of_d2_v2_f16c_fma(lhs: &[f16], rhs: &[f16]) -> f32 { + use crate::simd::emulate::emulate_mm_reduce_add_ps; + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len() as u32; + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut d2 = _mm_setzero_ps(); + while n >= 4 { + let x = _mm_cvtph_ps(_mm_loadu_si128(a.cast())); + let y = _mm_cvtph_ps(_mm_loadu_si128(b.cast())); + a = a.add(4); + b = b.add(4); + n -= 4; + let d = _mm_sub_ps(x, y); + d2 = _mm_fmadd_ps(d, d, d2); + } + let mut d2 = emulate_mm_reduce_add_ps(d2); + while n > 0 { + let x = a.read().to_f32(); + let y = b.read().to_f32(); + a = a.add(1); + b = b.add(1); + n -= 1; + let d = x - y; + d2 += d * d; + } + d2 + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_d2_v2_f16c_fma_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::simd::is_cpu_detected!("v2") + || !crate::simd::is_feature_detected!("f16c") + || !crate::simd::is_feature_detected!("fma") + { + println!("test {} ... skipped (v2:f16c:fma)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v2_f16c_fma(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "fp16")] + pub fn reduce_sum_of_d2_v8_3a_fp16(lhs: &[f16], rhs: &[f16]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + extern "C" { + fn fp16_reduce_sum_of_d2_v8_3a_fp16_unroll( + a: *const (), + b: *const (), + n: usize, + ) -> f32; + } + fp16_reduce_sum_of_d2_v8_3a_fp16_unroll( + lhs.as_ptr().cast(), + rhs.as_ptr().cast(), + lhs.len(), + ) + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_sum_of_d2_v8_3a_fp16_test() { + use rand::Rng; + const EPSILON: f32 = 6.0; + if !crate::simd::is_cpu_detected!("v8.3a") || !crate::simd::is_feature_detected!("fp16") { + println!("test {} ... 
skipped (v8.3a:fp16)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v8_3a_fp16(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + pub fn reduce_sum_of_d2_v8_3a_sve(lhs: &[f16], rhs: &[f16]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + extern "C" { + fn fp16_reduce_sum_of_d2_v8_3a_sve(a: *const (), b: *const (), n: usize) -> f32; + } + fp16_reduce_sum_of_d2_v8_3a_sve(lhs.as_ptr().cast(), rhs.as_ptr().cast(), lhs.len()) + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_sum_of_d2_v8_3a_sve_test() { + use rand::Rng; + const EPSILON: f32 = 6.0; + if !crate::simd::is_cpu_detected!("v8.3a") || !crate::simd::is_feature_detected!("sve") { + println!("test {} ... skipped (v8.3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v8_3a_sve(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
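Note: the v8.3a:fp16 kernels bound above handle the final n < 32 elements by copying them into zero-initialized stack buffers and running the full-width body once more (see cshim.c earlier in this diff), while the SVE kernels use predicated loads instead. Zero padding is harmless for both reductions, since padded lanes contribute x*y = 0 and (x-y)^2 = 0. A scalar rendition of that tail strategy, for illustration only:

```rust
fn d2_with_padded_tail(lhs: &[f32], rhs: &[f32]) -> f32 {
    const W: usize = 32; // unroll width used by the NEON fp16 kernels
    assert_eq!(lhs.len(), rhs.len());
    let mut acc = 0.0f32;
    let mut i = 0;
    while i + W <= lhs.len() {
        for k in 0..W {
            let d = lhs[i + k] - rhs[i + k];
            acc += d * d;
        }
        i += W;
    }
    if i < lhs.len() {
        // Tail: copy the remainder into zero-filled buffers and reuse the same
        // body; the padded positions add exactly zero to the accumulator.
        let (mut a, mut b) = ([0.0f32; W], [0.0f32; W]);
        a[..lhs.len() - i].copy_from_slice(&lhs[i..]);
        b[..rhs.len() - i].copy_from_slice(&rhs[i..]);
        for k in 0..W {
            let d = a[k] - b[k];
            acc += d * d;
        }
    }
    acc
}
```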
@@ -661,7 +968,7 @@ mod reduce_sum_of_d2 { } } - #[detect::multiversion(v4_avx512fp16 = import, v4 = import, v3 = import, v2, neon, fallback = export)] + #[crate::simd::multiversion(@"v4:avx512fp16", @"v4", @"v3", @"v2:f16c:fma", @"v8.3a:sve", @"v8.3a:fp16")] pub fn reduce_sum_of_d2(lhs: &[f16], rhs: &[f16]) -> f32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -680,7 +987,7 @@ mod reduce_sum_of_sparse_xy { use half::f16; - #[detect::multiversion(v4, v3, v2, neon, fallback)] + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] pub fn reduce_sum_of_sparse_xy(lidx: &[u32], lval: &[f16], ridx: &[u32], rval: &[f16]) -> f32 { use std::cmp::Ordering; assert_eq!(lidx.len(), lval.len()); @@ -713,7 +1020,7 @@ mod reduce_sum_of_sparse_d2 { use half::f16; - #[detect::multiversion(v4, v3, v2, neon, fallback)] + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] pub fn reduce_sum_of_sparse_d2(lidx: &[u32], lval: &[f16], ridx: &[u32], rval: &[f16]) -> f32 { use std::cmp::Ordering; assert_eq!(lidx.len(), lval.len()); @@ -748,3 +1055,177 @@ mod reduce_sum_of_sparse_d2 { d2 } } + +mod vector_add { + use half::f16; + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_add(lhs: &[f16], rhs: &[f16]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] + rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_add_inplace { + use half::f16; + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_add_inplace(lhs: &mut [f16], rhs: &[f16]) { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + for i in 0..n { + lhs[i] += rhs[i]; + } + } +} + +mod vector_sub { + use half::f16; + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_sub(lhs: &[f16], rhs: &[f16]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] - rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_mul { + use half::f16; + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_mul(lhs: &[f16], rhs: &[f16]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] * rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_mul_scalar { + use half::f16; + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_mul_scalar(lhs: &[f16], rhs: f32) -> Vec { + let rhs = f16::from_f32(rhs); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] * rhs); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_mul_scalar_inplace { + use half::f16; + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_mul_scalar_inplace(lhs: &mut [f16], rhs: f32) { + let rhs = f16::from_f32(rhs); + let n = lhs.len(); + for i in 0..n { + lhs[i] *= rhs; + } + } +} + +mod vector_abs_inplace { + use half::f16; + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_abs_inplace(this: &mut [f16]) { + let n = this.len(); + for i in 0..n { + this[i] = f16::from_f32(this[i].to_f32().abs()); + } + } +} + +mod vector_from_f32 { + use 
half::f16; + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_from_f32(this: &[f32]) -> Vec { + let n = this.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(f16::from_f32(this[i])); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_to_f32 { + use half::f16; + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_to_f32(this: &[f16]) -> Vec { + let n = this.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(this[i].to_f32()); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod kmeans_helper { + use half::f16; + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn kmeans_helper(this: &mut [f16], x: f32, y: f32) { + let x = f16::from_f32(x); + let y = f16::from_f32(y); + let n = this.len(); + for i in 0..n { + if i % 2 == 0 { + this[i] *= x; + } else { + this[i] *= y; + } + } + } +} diff --git a/crates/base/src/simd/f32.rs b/crates/base/src/simd/f32.rs index 6d1af3e8e..0b66fa44a 100644 --- a/crates/base/src/simd/f32.rs +++ b/crates/base/src/simd/f32.rs @@ -56,15 +56,9 @@ impl ScalarLike for f32 { self } - // FIXME: add manually-implemented SIMD version - #[detect::multiversion(v4, v3, v2, neon, fallback)] + #[inline(always)] fn reduce_or_of_is_zero(this: &[f32]) -> bool { - for &x in this { - if x == 0.0f32 { - return true; - } - } - false + reduce_or_of_is_zero::reduce_or_of_is_zero(this) } #[inline(always)] @@ -107,100 +101,38 @@ impl ScalarLike for f32 { reduce_sum_of_sparse_d2::reduce_sum_of_sparse_d2(lidx, lval, ridx, rval) } - #[detect::multiversion(v4, v3, v2, neon, fallback)] fn vector_add(lhs: &[f32], rhs: &[f32]) -> Vec { - assert_eq!(lhs.len(), rhs.len()); - let n = lhs.len(); - let mut r = Vec::::with_capacity(n); - for i in 0..n { - unsafe { - r.as_mut_ptr().add(i).write(lhs[i] + rhs[i]); - } - } - unsafe { - r.set_len(n); - } - r + vector_add::vector_add(lhs, rhs) } - #[detect::multiversion(v4, v3, v2, neon, fallback)] fn vector_add_inplace(lhs: &mut [f32], rhs: &[f32]) { - assert_eq!(lhs.len(), rhs.len()); - let n = lhs.len(); - for i in 0..n { - lhs[i] += rhs[i]; - } + vector_add_inplace::vector_add_inplace(lhs, rhs) } - #[detect::multiversion(v4, v3, v2, neon, fallback)] fn vector_sub(lhs: &[f32], rhs: &[f32]) -> Vec { - assert_eq!(lhs.len(), rhs.len()); - let n = lhs.len(); - let mut r = Vec::::with_capacity(n); - for i in 0..n { - unsafe { - r.as_mut_ptr().add(i).write(lhs[i] - rhs[i]); - } - } - unsafe { - r.set_len(n); - } - r + vector_sub::vector_sub(lhs, rhs) } - #[detect::multiversion(v4, v3, v2, neon, fallback)] fn vector_mul(lhs: &[f32], rhs: &[f32]) -> Vec { - assert_eq!(lhs.len(), rhs.len()); - let n = lhs.len(); - let mut r = Vec::::with_capacity(n); - for i in 0..n { - unsafe { - r.as_mut_ptr().add(i).write(lhs[i] * rhs[i]); - } - } - unsafe { - r.set_len(n); - } - r + vector_mul::vector_mul(lhs, rhs) } - #[detect::multiversion(v4, v3, v2, neon, fallback)] fn vector_mul_scalar(lhs: &[f32], rhs: f32) -> Vec { - let n = lhs.len(); - let mut r = Vec::::with_capacity(n); - for i in 0..n { - unsafe { - r.as_mut_ptr().add(i).write(lhs[i] * rhs); - } - } - unsafe { - r.set_len(n); - } - r + vector_mul_scalar::vector_mul_scalar(lhs, rhs) } - #[detect::multiversion(v4, v3, v2, neon, fallback)] fn vector_mul_scalar_inplace(lhs: &mut [f32], rhs: f32) { - let n = lhs.len(); - for i in 0..n { - lhs[i] *= rhs; - } + 
vector_mul_scalar_inplace::vector_mul_scalar_inplace(lhs, rhs); } - #[detect::multiversion(v4, v3, v2, neon, fallback)] fn vector_abs_inplace(this: &mut [f32]) { - let n = this.len(); - for i in 0..n { - this[i] = this[i].abs(); - } + vector_abs_inplace::vector_abs_inplace(this); } - #[detect::multiversion(v4, v3, v2, neon, fallback)] fn vector_from_f32(this: &[f32]) -> Vec { this.to_vec() } - #[detect::multiversion(v4, v3, v2, neon, fallback)] fn vector_to_f32(this: &[f32]) -> Vec { this.to_vec() } @@ -209,23 +141,30 @@ impl ScalarLike for f32 { this } - #[detect::multiversion(v4, v3, v2, neon, fallback)] fn kmeans_helper(this: &mut [f32], x: f32, y: f32) { - let n = this.len(); - for i in 0..n { - if i % 2 == 0 { - this[i] *= x; - } else { - this[i] *= y; + kmeans_helper::kmeans_helper(this, x, y) + } +} + +mod reduce_or_of_is_zero { + // FIXME: add manually-implemented SIMD version + + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_or_of_is_zero(this: &[f32]) -> bool { + for &x in this { + if x == 0.0f32 { + return true; } } + false } } mod reduce_sum_of_x { + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4")] - unsafe fn reduce_sum_of_x_v4(this: &[f32]) -> f32 { + #[crate::simd::target_cpu(enable = "v4")] + fn reduce_sum_of_x_v4(this: &[f32]) -> f32 { unsafe { use std::arch::x86_64::*; let mut n = this.len(); @@ -251,8 +190,7 @@ mod reduce_sum_of_x { fn reduce_sum_of_x_v4_test() { use rand::Rng; const EPSILON: f32 = 0.008; - detect::init(); - if !detect::v4::detect() { + if !crate::simd::is_cpu_detected!("v4") { println!("test {} ... skipped (v4)", module_path!()); return; } @@ -264,8 +202,8 @@ mod reduce_sum_of_x { .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_x_v4(&this) }; - let fallback = unsafe { reduce_sum_of_x_fallback(&this) }; + let specialized = unsafe { reduce_sum_of_x_v4(this) }; + let fallback = reduce_sum_of_x_fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -274,9 +212,10 @@ mod reduce_sum_of_x { } } + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v3")] - unsafe fn reduce_sum_of_x_v3(this: &[f32]) -> f32 { + #[crate::simd::target_cpu(enable = "v3")] + fn reduce_sum_of_x_v3(this: &[f32]) -> f32 { use crate::simd::emulate::emulate_mm256_reduce_add_ps; unsafe { use std::arch::x86_64::*; @@ -312,8 +251,7 @@ mod reduce_sum_of_x { fn reduce_sum_of_x_v3_test() { use rand::Rng; const EPSILON: f32 = 0.008; - detect::init(); - if !detect::v3::detect() { + if !crate::simd::is_cpu_detected!("v3") { println!("test {} ... skipped (v3)", module_path!()); return; } @@ -326,7 +264,7 @@ mod reduce_sum_of_x { for z in 3984..4016 { let this = &this[..z]; let specialized = unsafe { reduce_sum_of_x_v3(this) }; - let fallback = unsafe { reduce_sum_of_x_fallback(this) }; + let fallback = reduce_sum_of_x_fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
@@ -335,51 +273,41 @@ mod reduce_sum_of_x { } } - #[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)] - pub fn reduce_sum_of_x(this: &[f32]) -> f32 { - let n = this.len(); - let mut sum = 0.0f32; - for i in 0..n { - sum += this[i]; - } - sum - } -} - -mod reduce_sum_of_abs_x { + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4")] - unsafe fn reduce_sum_of_abs_x_v4(this: &[f32]) -> f32 { + #[crate::simd::target_cpu(enable = "v2")] + fn reduce_sum_of_x_v2(this: &[f32]) -> f32 { + use crate::simd::emulate::emulate_mm_reduce_add_ps; unsafe { use std::arch::x86_64::*; let mut n = this.len(); let mut a = this.as_ptr(); - let mut sum = _mm512_setzero_ps(); - while n >= 16 { - let x = _mm512_loadu_ps(a); - let abs_x = _mm512_abs_ps(x); - a = a.add(16); - n -= 16; - sum = _mm512_add_ps(abs_x, sum); + let mut sum = _mm_setzero_ps(); + while n >= 4 { + let x = _mm_loadu_ps(a); + a = a.add(4); + n -= 4; + sum = _mm_add_ps(x, sum); } - if n > 0 { - let mask = _bzhi_u32(0xffff, n as u32) as u16; - let x = _mm512_maskz_loadu_ps(mask, a); - let abs_x = _mm512_abs_ps(x); - sum = _mm512_add_ps(abs_x, sum); + let mut sum = emulate_mm_reduce_add_ps(sum); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + sum += x; } - _mm512_reduce_add_ps(sum) + sum } } #[cfg(all(target_arch = "x86_64", test))] #[test] - fn reduce_sum_of_abs_x_v4_test() { + fn reduce_sum_of_x_v2_test() { use rand::Rng; const EPSILON: f32 = 0.008; - detect::init(); - if !detect::v4::detect() { - println!("test {} ... skipped (v4)", module_path!()); + if !crate::simd::is_cpu_detected!("v2") { + println!("test {} ... skipped (v2)", module_path!()); return; } let mut rng = rand::thread_rng(); @@ -390,8 +318,8 @@ mod reduce_sum_of_abs_x { .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_abs_x_v4(&this) }; - let fallback = unsafe { reduce_sum_of_abs_x_fallback(&this) }; + let specialized = unsafe { reduce_sum_of_x_v2(this) }; + let fallback = reduce_sum_of_x_fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
@@ -400,51 +328,40 @@ mod reduce_sum_of_abs_x { } } - #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v3")] - unsafe fn reduce_sum_of_abs_x_v3(this: &[f32]) -> f32 { - use crate::simd::emulate::emulate_mm256_reduce_add_ps; + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + fn reduce_sum_of_x_v8_3a(this: &[f32]) -> f32 { unsafe { - use std::arch::x86_64::*; - let abs = _mm256_castsi256_ps(_mm256_srli_epi32(_mm256_set1_epi32(-1), 1)); + use std::arch::aarch64::*; let mut n = this.len(); let mut a = this.as_ptr(); - let mut sum = _mm256_setzero_ps(); - while n >= 8 { - let x = _mm256_loadu_ps(a); - let abs_x = _mm256_and_ps(abs, x); - a = a.add(8); - n -= 8; - sum = _mm256_add_ps(abs_x, sum); - } - if n >= 4 { - let x = _mm256_zextps128_ps256(_mm_loadu_ps(a)); - let abs_x = _mm256_and_ps(abs, x); + let mut sum = vdupq_n_f32(0.0); + while n >= 4 { + let x = vld1q_f32(a); a = a.add(4); n -= 4; - sum = _mm256_add_ps(abs_x, sum); + sum = vaddq_f32(x, sum); } - let mut sum = emulate_mm256_reduce_add_ps(sum); + let mut sum = vaddvq_f32(sum); // this hint is used to disable loop unrolling while std::hint::black_box(n) > 0 { let x = a.read(); - let abs_x = x.abs(); a = a.add(1); n -= 1; - sum += abs_x; + sum += x; } sum } } - #[cfg(all(target_arch = "x86_64", test))] + #[cfg(all(target_arch = "aarch64", test))] #[test] - fn reduce_sum_of_abs_x_v3_test() { + fn reduce_sum_of_x_v8_3a_test() { use rand::Rng; const EPSILON: f32 = 0.008; - detect::init(); - if !detect::v3::detect() { - println!("test {} ... skipped (v3)", module_path!()); + if !crate::simd::is_cpu_detected!("v8.3a") { + println!("test {} ... skipped (v8_3a)", module_path!()); return; } let mut rng = rand::thread_rng(); @@ -455,8 +372,8 @@ mod reduce_sum_of_abs_x { .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_abs_x_v3(this) }; - let fallback = unsafe { reduce_sum_of_abs_x_fallback(this) }; + let specialized = unsafe { reduce_sum_of_x_v8_3a(this) }; + let fallback = reduce_sum_of_x_fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -465,48 +382,90 @@ mod reduce_sum_of_abs_x { } } - #[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)] - pub fn reduce_sum_of_abs_x(this: &[f32]) -> f32 { + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + fn reduce_sum_of_x_v8_3a_sve(this: &[f32]) -> f32 { + unsafe { + extern "C" { + fn fp32_reduce_sum_of_x_v8_3a_sve(this: *const f32, n: usize) -> f32; + } + fp32_reduce_sum_of_x_v8_3a_sve(this.as_ptr(), this.len()) + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_sum_of_x_v8_3a_sve_test() { + use rand::Rng; + const EPSILON: f32 = 0.008; + if !crate::simd::is_cpu_detected!("v8.3a") || !crate::simd::is_feature_detected!("sve") { + println!("test {} ... skipped (v8_3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_v8_3a_sve(this) }; + let fallback = reduce_sum_of_x_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[crate::simd::multiversion(@"v4", @"v3", @"v2", @"v8.3a:sve", @"v8.3a")] + pub fn reduce_sum_of_x(this: &[f32]) -> f32 { let n = this.len(); let mut sum = 0.0f32; for i in 0..n { - sum += this[i].abs(); + sum += this[i]; } sum } } -mod reduce_sum_of_x2 { +mod reduce_sum_of_abs_x { + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4")] - unsafe fn reduce_sum_of_x2_v4(this: &[f32]) -> f32 { + #[crate::simd::target_cpu(enable = "v4")] + fn reduce_sum_of_abs_x_v4(this: &[f32]) -> f32 { unsafe { use std::arch::x86_64::*; let mut n = this.len(); let mut a = this.as_ptr(); - let mut x2 = _mm512_setzero_ps(); + let mut sum = _mm512_setzero_ps(); while n >= 16 { let x = _mm512_loadu_ps(a); + let abs_x = _mm512_abs_ps(x); a = a.add(16); n -= 16; - x2 = _mm512_fmadd_ps(x, x, x2); + sum = _mm512_add_ps(abs_x, sum); } if n > 0 { let mask = _bzhi_u32(0xffff, n as u32) as u16; let x = _mm512_maskz_loadu_ps(mask, a); - x2 = _mm512_fmadd_ps(x, x, x2); + let abs_x = _mm512_abs_ps(x); + sum = _mm512_add_ps(abs_x, sum); } - _mm512_reduce_add_ps(x2) + _mm512_reduce_add_ps(sum) } } #[cfg(all(target_arch = "x86_64", test))] #[test] - fn reduce_sum_of_x2_v4_test() { + fn reduce_sum_of_abs_x_v4_test() { use rand::Rng; - const EPSILON: f32 = 0.006; - detect::init(); - if !detect::v4::detect() { + const EPSILON: f32 = 0.008; + if !crate::simd::is_cpu_detected!("v4") { println!("test {} ... skipped (v4)", module_path!()); return; } @@ -518,8 +477,8 @@ mod reduce_sum_of_x2 { .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_x2_v4(&this) }; - let fallback = unsafe { reduce_sum_of_x2_fallback(&this) }; + let specialized = unsafe { reduce_sum_of_abs_x_v4(this) }; + let fallback = reduce_sum_of_abs_x_fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -528,46 +487,50 @@ mod reduce_sum_of_x2 { } } + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v3")] - unsafe fn reduce_sum_of_x2_v3(this: &[f32]) -> f32 { + #[crate::simd::target_cpu(enable = "v3")] + fn reduce_sum_of_abs_x_v3(this: &[f32]) -> f32 { use crate::simd::emulate::emulate_mm256_reduce_add_ps; unsafe { use std::arch::x86_64::*; + let abs = _mm256_castsi256_ps(_mm256_srli_epi32(_mm256_set1_epi32(-1), 1)); let mut n = this.len(); let mut a = this.as_ptr(); - let mut x2 = _mm256_setzero_ps(); + let mut sum = _mm256_setzero_ps(); while n >= 8 { let x = _mm256_loadu_ps(a); + let abs_x = _mm256_and_ps(abs, x); a = a.add(8); n -= 8; - x2 = _mm256_fmadd_ps(x, x, x2); + sum = _mm256_add_ps(abs_x, sum); } if n >= 4 { let x = _mm256_zextps128_ps256(_mm_loadu_ps(a)); + let abs_x = _mm256_and_ps(abs, x); a = a.add(4); n -= 4; - x2 = _mm256_fmadd_ps(x, x, x2); + sum = _mm256_add_ps(abs_x, sum); } - let mut x2 = emulate_mm256_reduce_add_ps(x2); + let mut sum = emulate_mm256_reduce_add_ps(sum); // this hint is used to disable loop unrolling while std::hint::black_box(n) > 0 { let x = a.read(); + let abs_x = x.abs(); a = a.add(1); n -= 1; - x2 += x * x; + sum += abs_x; } - x2 + sum } } #[cfg(all(target_arch = "x86_64", test))] #[test] - fn reduce_sum_of_x2_v3_test() { + fn reduce_sum_of_abs_x_v3_test() { use rand::Rng; - const EPSILON: f32 = 0.006; - detect::init(); - if !detect::v3::detect() { + const EPSILON: f32 = 0.008; + if !crate::simd::is_cpu_detected!("v3") { println!("test {} ... 
skipped (v3)", module_path!()); return; } @@ -579,8 +542,8 @@ mod reduce_sum_of_x2 { .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_x2_v3(this) }; - let fallback = unsafe { reduce_sum_of_x2_fallback(this) }; + let specialized = unsafe { reduce_sum_of_abs_x_v3(this) }; + let fallback = reduce_sum_of_abs_x_fallback(this); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -589,180 +552,1088 @@ mod reduce_sum_of_x2 { } } - #[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)] - pub fn reduce_sum_of_x2(this: &[f32]) -> f32 { - let n = this.len(); - let mut x2 = 0.0f32; - for i in 0..n { - x2 += this[i] * this[i]; - } - x2 - } -} - -mod reduce_min_max_of_x { - // Semanctics of `f32::min` is different from `_mm256_min_ps`, - // which may lead to issues... - + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4")] - unsafe fn reduce_min_max_of_x_v4(this: &[f32]) -> (f32, f32) { + #[crate::simd::target_cpu(enable = "v2")] + fn reduce_sum_of_abs_x_v2(this: &[f32]) -> f32 { + use crate::simd::emulate::emulate_mm_reduce_add_ps; unsafe { use std::arch::x86_64::*; + let abs = _mm_castsi128_ps(_mm_srli_epi32(_mm_set1_epi32(-1), 1)); let mut n = this.len(); let mut a = this.as_ptr(); - let mut min = _mm512_set1_ps(f32::INFINITY); - let mut max = _mm512_set1_ps(f32::NEG_INFINITY); - while n >= 16 { - let x = _mm512_loadu_ps(a); - a = a.add(16); - n -= 16; - min = _mm512_min_ps(x, min); - max = _mm512_max_ps(x, max); + let mut sum = _mm_setzero_ps(); + while n >= 4 { + let x = _mm_loadu_ps(a); + let abs_x = _mm_and_ps(abs, x); + a = a.add(4); + n -= 4; + sum = _mm_add_ps(abs_x, sum); } - if n > 0 { - let mask = _bzhi_u32(0xffff, n as u32) as u16; - let x = _mm512_maskz_loadu_ps(mask, a); - min = _mm512_mask_min_ps(min, mask, x, min); - max = _mm512_mask_max_ps(max, mask, x, max); + let mut sum = emulate_mm_reduce_add_ps(sum); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let abs_x = x.abs(); + a = a.add(1); + n -= 1; + sum += abs_x; } - let min = _mm512_reduce_min_ps(min); - let max = _mm512_reduce_max_ps(max); - (min, max) + sum } } #[cfg(all(target_arch = "x86_64", test))] #[test] - fn reduce_min_max_of_x_v4_test() { + fn reduce_sum_of_abs_x_v2_test() { + use rand::Rng; + const EPSILON: f32 = 0.008; + if !crate::simd::is_cpu_detected!("v2") { + println!("test {} ... skipped (v2)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_abs_x_v2(this) }; + let fallback = reduce_sum_of_abs_x_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + fn reduce_sum_of_abs_x_v8_3a(this: &[f32]) -> f32 { + unsafe { + use std::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = vdupq_n_f32(0.0); + while n >= 4 { + let x = vld1q_f32(a); + let abs_x = vabsq_f32(x); + a = a.add(4); + n -= 4; + sum = vaddq_f32(abs_x, sum); + } + let mut sum = vaddvq_f32(sum); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let abs_x = x.abs(); + a = a.add(1); + n -= 1; + sum += abs_x; + } + sum + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_sum_of_abs_x_v8_3a_test() { + use rand::Rng; + const EPSILON: f32 = 0.008; + if !crate::simd::is_cpu_detected!("v8.3a") { + println!("test {} ... skipped (v8.3a)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_abs_x_v8_3a(this) }; + let fallback = reduce_sum_of_abs_x_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + fn reduce_sum_of_abs_x_v8_3a_sve(this: &[f32]) -> f32 { + unsafe { + extern "C" { + fn fp32_reduce_sum_of_abs_x_v8_3a_sve(this: *const f32, n: usize) -> f32; + } + fp32_reduce_sum_of_abs_x_v8_3a_sve(this.as_ptr(), this.len()) + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_sum_of_abs_x_v8_3a_sve_test() { + use rand::Rng; + const EPSILON: f32 = 0.008; + if !crate::simd::is_cpu_detected!("v8.3a") || !crate::simd::is_feature_detected!("sve") { + println!("test {} ... skipped (v8.3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_abs_x_v8_3a_sve(this) }; + let fallback = reduce_sum_of_abs_x_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[crate::simd::multiversion(@"v4", @"v3", @"v2", @"v8.3a:sve", @"v8.3a")] + pub fn reduce_sum_of_abs_x(this: &[f32]) -> f32 { + let n = this.len(); + let mut sum = 0.0f32; + for i in 0..n { + sum += this[i].abs(); + } + sum + } +} + +mod reduce_sum_of_x2 { + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::simd::target_cpu(enable = "v4")] + fn reduce_sum_of_x2_v4(this: &[f32]) -> f32 { + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut x2 = _mm512_setzero_ps(); + while n >= 16 { + let x = _mm512_loadu_ps(a); + a = a.add(16); + n -= 16; + x2 = _mm512_fmadd_ps(x, x, x2); + } + if n > 0 { + let mask = _bzhi_u32(0xffff, n as u32) as u16; + let x = _mm512_maskz_loadu_ps(mask, a); + x2 = _mm512_fmadd_ps(x, x, x2); + } + _mm512_reduce_add_ps(x2) + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_x2_v4_test() { use rand::Rng; - detect::init(); - if !detect::v4::detect() { + const EPSILON: f32 = 0.006; + if !crate::simd::is_cpu_detected!("v4") { println!("test {} ... 
skipped (v4)", module_path!()); return; } let mut rng = rand::thread_rng(); for _ in 0..256 { - let n = 200; - let x = (0..n) + let n = 4016; + let this = (0..n) .map(|_| rng.gen_range(-1.0..=1.0)) .collect::>(); - for z in 50..200 { - let x = &x[..z]; - let specialized = unsafe { reduce_min_max_of_x_v4(x) }; - let fallback = unsafe { reduce_min_max_of_x_fallback(x) }; - assert_eq!(specialized.0, fallback.0); - assert_eq!(specialized.1, fallback.1); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x2_v4(this) }; + let fallback = reduce_sum_of_x2_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); } } } + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v3")] - unsafe fn reduce_min_max_of_x_v3(this: &[f32]) -> (f32, f32) { - use crate::simd::emulate::emulate_mm256_reduce_max_ps; - use crate::simd::emulate::emulate_mm256_reduce_min_ps; + #[crate::simd::target_cpu(enable = "v3")] + fn reduce_sum_of_x2_v3(this: &[f32]) -> f32 { + use crate::simd::emulate::emulate_mm256_reduce_add_ps; unsafe { use std::arch::x86_64::*; let mut n = this.len(); let mut a = this.as_ptr(); - let mut min = _mm256_set1_ps(f32::INFINITY); - let mut max = _mm256_set1_ps(f32::NEG_INFINITY); + let mut x2 = _mm256_setzero_ps(); while n >= 8 { let x = _mm256_loadu_ps(a); a = a.add(8); n -= 8; - min = _mm256_min_ps(x, min); - max = _mm256_max_ps(x, max); + x2 = _mm256_fmadd_ps(x, x, x2); } - let mut min = emulate_mm256_reduce_min_ps(min); - let mut max = emulate_mm256_reduce_max_ps(max); + if n >= 4 { + let x = _mm256_zextps128_ps256(_mm_loadu_ps(a)); + a = a.add(4); + n -= 4; + x2 = _mm256_fmadd_ps(x, x, x2); + } + let mut x2 = emulate_mm256_reduce_add_ps(x2); // this hint is used to disable loop unrolling while std::hint::black_box(n) > 0 { let x = a.read(); a = a.add(1); n -= 1; - min = x.min(min); - max = x.max(max); + x2 += x * x; } - (min, max) + x2 } } #[cfg(all(target_arch = "x86_64", test))] #[test] - fn reduce_min_max_of_x_v3_test() { + fn reduce_sum_of_x2_v3_test() { use rand::Rng; - detect::init(); - if !detect::v3::detect() { + const EPSILON: f32 = 0.006; + if !crate::simd::is_cpu_detected!("v3") { println!("test {} ... skipped (v3)", module_path!()); return; } let mut rng = rand::thread_rng(); for _ in 0..256 { - let n = 200; - let x = (0..n) + let n = 4016; + let this = (0..n) .map(|_| rng.gen_range(-1.0..=1.0)) .collect::>(); - for z in 50..200 { - let x = &x[..z]; - let specialized = unsafe { reduce_min_max_of_x_v3(x) }; - let fallback = unsafe { reduce_min_max_of_x_fallback(x) }; - assert_eq!(specialized.0, fallback.0,); - assert_eq!(specialized.1, fallback.1,); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x2_v3(this) }; + let fallback = reduce_sum_of_x2_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::simd::target_cpu(enable = "v2")] + #[target_feature(enable = "fma")] + fn reduce_sum_of_x2_v2_fma(this: &[f32]) -> f32 { + use crate::simd::emulate::emulate_mm_reduce_add_ps; + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut x2 = _mm_setzero_ps(); + while n >= 4 { + let x = _mm_loadu_ps(a); + a = a.add(4); + n -= 4; + x2 = _mm_fmadd_ps(x, x, x2); + } + let mut x2 = emulate_mm_reduce_add_ps(x2); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + x2 += x * x; + } + x2 + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_x2_v2_fma_test() { + use rand::Rng; + const EPSILON: f32 = 0.006; + if !crate::simd::is_cpu_detected!("v2") || !crate::simd::is_feature_detected!("fma") { + println!("test {} ... skipped (v2:fma)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x2_v2_fma(this) }; + let fallback = reduce_sum_of_x2_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + fn reduce_sum_of_x2_v8_3a(this: &[f32]) -> f32 { + unsafe { + use std::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut x2 = vdupq_n_f32(0.0); + while n >= 4 { + let x = vld1q_f32(a); + a = a.add(4); + n -= 4; + x2 = vfmaq_f32(x2, x, x); + } + let mut x2 = vaddvq_f32(x2); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + x2 += x * x; + } + x2 + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_sum_of_x2_v8_3a_test() { + use rand::Rng; + const EPSILON: f32 = 0.006; + if !crate::simd::is_cpu_detected!("v8.3a") { + println!("test {} ... skipped (v8.3a)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x2_v8_3a(this) }; + let fallback = reduce_sum_of_x2_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + fn reduce_sum_of_x2_v8_3a_sve(this: &[f32]) -> f32 { + unsafe { + extern "C" { + fn fp32_reduce_sum_of_x2_v8_3a_sve(this: *const f32, n: usize) -> f32; + } + fp32_reduce_sum_of_x2_v8_3a_sve(this.as_ptr(), this.len()) + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_sum_of_x2_v8_3a_sve_test() { + use rand::Rng; + const EPSILON: f32 = 0.006; + if !crate::simd::is_cpu_detected!("v8.3a") || !crate::simd::is_feature_detected!("sve") { + println!("test {} ... 
skipped (v8.3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::<Vec<f32>>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x2_v8_3a_sve(this) }; + let fallback = reduce_sum_of_x2_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[crate::simd::multiversion(@"v4", @"v3", @"v2:fma", @"v8.3a:sve", @"v8.3a")] + pub fn reduce_sum_of_x2(this: &[f32]) -> f32 { + let n = this.len(); + let mut x2 = 0.0f32; + for i in 0..n { + x2 += this[i] * this[i]; + } + x2 + } +} + +mod reduce_min_max_of_x { + // Semantics of `f32::min` is different from `_mm256_min_ps`, + // which may lead to issues... + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::simd::target_cpu(enable = "v4")] + fn reduce_min_max_of_x_v4(this: &[f32]) -> (f32, f32) { + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut min = _mm512_set1_ps(f32::INFINITY); + let mut max = _mm512_set1_ps(f32::NEG_INFINITY); + while n >= 16 { + let x = _mm512_loadu_ps(a); + a = a.add(16); + n -= 16; + min = _mm512_min_ps(x, min); + max = _mm512_max_ps(x, max); + } + if n > 0 { + let mask = _bzhi_u32(0xffff, n as u32) as u16; + let x = _mm512_maskz_loadu_ps(mask, a); + min = _mm512_mask_min_ps(min, mask, x, min); + max = _mm512_mask_max_ps(max, mask, x, max); + } + let min = _mm512_reduce_min_ps(min); + let max = _mm512_reduce_max_ps(max); + (min, max) + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_min_max_of_x_v4_test() { + use rand::Rng; + if !crate::simd::is_cpu_detected!("v4") { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 200; + let x = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::<Vec<f32>>(); + for z in 50..200 { + let x = &x[..z]; + let specialized = unsafe { reduce_min_max_of_x_v4(x) }; + let fallback = reduce_min_max_of_x_fallback(x); + assert_eq!(specialized.0, fallback.0); + assert_eq!(specialized.1, fallback.1); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::simd::target_cpu(enable = "v3")] + fn reduce_min_max_of_x_v3(this: &[f32]) -> (f32, f32) { + use crate::simd::emulate::emulate_mm256_reduce_max_ps; + use crate::simd::emulate::emulate_mm256_reduce_min_ps; + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut min = _mm256_set1_ps(f32::INFINITY); + let mut max = _mm256_set1_ps(f32::NEG_INFINITY); + while n >= 8 { + let x = _mm256_loadu_ps(a); + a = a.add(8); + n -= 8; + min = _mm256_min_ps(x, min); + max = _mm256_max_ps(x, max); + } + let mut min = emulate_mm256_reduce_min_ps(min); + let mut max = emulate_mm256_reduce_max_ps(max); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + min = x.min(min); + max = x.max(max); + } + (min, max) + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_min_max_of_x_v3_test() { + use rand::Rng; + if !crate::simd::is_cpu_detected!("v3") { + println!("test {} ...
skipped (v3)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 200; + let x = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 50..200 { + let x = &x[..z]; + let specialized = unsafe { reduce_min_max_of_x_v3(x) }; + let fallback = reduce_min_max_of_x_fallback(x); + assert_eq!(specialized.0, fallback.0,); + assert_eq!(specialized.1, fallback.1,); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::simd::target_cpu(enable = "v2")] + fn reduce_min_max_of_x_v2(this: &[f32]) -> (f32, f32) { + use crate::simd::emulate::emulate_mm_reduce_max_ps; + use crate::simd::emulate::emulate_mm_reduce_min_ps; + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut min = _mm_set1_ps(f32::INFINITY); + let mut max = _mm_set1_ps(f32::NEG_INFINITY); + while n >= 4 { + let x = _mm_loadu_ps(a); + a = a.add(4); + n -= 4; + min = _mm_min_ps(x, min); + max = _mm_max_ps(x, max); + } + let mut min = emulate_mm_reduce_min_ps(min); + let mut max = emulate_mm_reduce_max_ps(max); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + min = x.min(min); + max = x.max(max); + } + (min, max) + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_min_max_of_x_v2_test() { + use rand::Rng; + if !crate::simd::is_cpu_detected!("v2") { + println!("test {} ... skipped (v2)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 200; + let x = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 50..200 { + let x = &x[..z]; + let specialized = unsafe { reduce_min_max_of_x_v2(x) }; + let fallback = reduce_min_max_of_x_fallback(x); + assert_eq!(specialized.0, fallback.0,); + assert_eq!(specialized.1, fallback.1,); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + fn reduce_min_max_of_x_v8_3a(this: &[f32]) -> (f32, f32) { + unsafe { + use std::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut min = vdupq_n_f32(f32::INFINITY); + let mut max = vdupq_n_f32(f32::NEG_INFINITY); + while n >= 4 { + let x = vld1q_f32(a); + a = a.add(4); + n -= 4; + min = vminq_f32(x, min); + max = vmaxq_f32(x, max); + } + let mut min = vminvq_f32(min); + let mut max = vmaxvq_f32(max); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + min = x.min(min); + max = x.max(max); + } + (min, max) + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_min_max_of_x_v8_3a_test() { + use rand::Rng; + if !crate::simd::is_cpu_detected!("v8.3a") { + println!("test {} ... 
skipped (v8.3a)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 200; + let x = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 50..200 { + let x = &x[..z]; + let specialized = unsafe { reduce_min_max_of_x_v8_3a(x) }; + let fallback = reduce_min_max_of_x_fallback(x); + assert_eq!(specialized.0, fallback.0,); + assert_eq!(specialized.1, fallback.1,); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + fn reduce_min_max_of_x_v8_3a_sve(this: &[f32]) -> (f32, f32) { + let mut min = f32::INFINITY; + let mut max = -f32::INFINITY; + unsafe { + extern "C" { + fn fp32_reduce_min_max_of_x_v8_3a_sve( + this: *const f32, + n: usize, + out_min: &mut f32, + out_max: &mut f32, + ); + } + fp32_reduce_min_max_of_x_v8_3a_sve(this.as_ptr(), this.len(), &mut min, &mut max); + } + (min, max) + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_min_max_of_x_v8_3a_sve_test() { + use rand::Rng; + if !crate::simd::is_cpu_detected!("v8.3a") || !crate::simd::is_feature_detected!("sve") { + println!("test {} ... skipped (v8.3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 200; + let x = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 50..200 { + let x = &x[..z]; + let specialized = unsafe { reduce_min_max_of_x_v8_3a_sve(x) }; + let fallback = reduce_min_max_of_x_fallback(x); + assert_eq!(specialized.0, fallback.0,); + assert_eq!(specialized.1, fallback.1,); + } + } + } + + #[crate::simd::multiversion(@"v4", @"v3", @"v2", @"v8.3a:sve", @"v8.3a")] + pub fn reduce_min_max_of_x(this: &[f32]) -> (f32, f32) { + let mut min = f32::INFINITY; + let mut max = f32::NEG_INFINITY; + let n = this.len(); + for i in 0..n { + min = min.min(this[i]); + max = max.max(this[i]); + } + (min, max) + } +} + +mod reduce_sum_of_xy { + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::simd::target_cpu(enable = "v4")] + fn reduce_sum_of_xy_v4(lhs: &[f32], rhs: &[f32]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut xy = _mm512_setzero_ps(); + while n >= 16 { + let x = _mm512_loadu_ps(a); + let y = _mm512_loadu_ps(b); + a = a.add(16); + b = b.add(16); + n -= 16; + xy = _mm512_fmadd_ps(x, y, xy); + } + if n > 0 { + let mask = _bzhi_u32(0xffff, n as u32) as u16; + let x = _mm512_maskz_loadu_ps(mask, a); + let y = _mm512_maskz_loadu_ps(mask, b); + xy = _mm512_fmadd_ps(x, y, xy); + } + _mm512_reduce_add_ps(xy) + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_xy_v4_test() { + use rand::Rng; + const EPSILON: f32 = 0.004; + if !crate::simd::is_cpu_detected!("v4") { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v4(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::simd::target_cpu(enable = "v3")] + fn reduce_sum_of_xy_v3(lhs: &[f32], rhs: &[f32]) -> f32 { + use crate::simd::emulate::emulate_mm256_reduce_add_ps; + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut xy = _mm256_setzero_ps(); + while n >= 8 { + let x = _mm256_loadu_ps(a); + let y = _mm256_loadu_ps(b); + a = a.add(8); + b = b.add(8); + n -= 8; + xy = _mm256_fmadd_ps(x, y, xy); + } + if n >= 4 { + let x = _mm256_zextps128_ps256(_mm_loadu_ps(a)); + let y = _mm256_zextps128_ps256(_mm_loadu_ps(b)); + a = a.add(4); + b = b.add(4); + n -= 4; + xy = _mm256_fmadd_ps(x, y, xy); + } + let mut xy = emulate_mm256_reduce_add_ps(xy); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let y = b.read(); + a = a.add(1); + b = b.add(1); + n -= 1; + xy += x * y; + } + xy + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_xy_v3_test() { + use rand::Rng; + const EPSILON: f32 = 0.004; + if !crate::simd::is_cpu_detected!("v3") { + println!("test {} ... skipped (v3)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v3(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::simd::target_cpu(enable = "v2")] + #[target_feature(enable = "fma")] + fn reduce_sum_of_xy_v2_fma(lhs: &[f32], rhs: &[f32]) -> f32 { + use crate::simd::emulate::emulate_mm_reduce_add_ps; + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut xy = _mm_setzero_ps(); + while n >= 4 { + let x = _mm_loadu_ps(a); + let y = _mm_loadu_ps(b); + a = a.add(4); + b = b.add(4); + n -= 4; + xy = _mm_fmadd_ps(x, y, xy); + } + let mut xy = emulate_mm_reduce_add_ps(xy); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let y = b.read(); + a = a.add(1); + b = b.add(1); + n -= 1; + xy += x * y; + } + xy + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_xy_v2_fma_test() { + use rand::Rng; + const EPSILON: f32 = 0.004; + if !crate::simd::is_cpu_detected!("v2") || !crate::simd::is_feature_detected!("fma") { + println!("test {} ... skipped (v2:fma)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v2_fma(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + fn reduce_sum_of_xy_v8_3a(lhs: &[f32], rhs: &[f32]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::aarch64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut xy = vdupq_n_f32(0.0); + while n >= 4 { + let x = vld1q_f32(a); + let y = vld1q_f32(b); + a = a.add(4); + b = b.add(4); + n -= 4; + xy = vfmaq_f32(xy, x, y); + } + let mut xy = vaddvq_f32(xy); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let y = b.read(); + a = a.add(1); + b = b.add(1); + n -= 1; + xy += x * y; + } + xy + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_sum_of_xy_v8_3a_test() { + use rand::Rng; + const EPSILON: f32 = 0.004; + if !crate::simd::is_cpu_detected!("v8.3a") { + println!("test {} ... skipped (v8.3a)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v8_3a(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + fn reduce_sum_of_xy_v8_3a_sve(lhs: &[f32], rhs: &[f32]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + extern "C" { + fn fp32_reduce_sum_of_xy_v8_3a_sve(a: *const f32, b: *const f32, n: usize) -> f32; + } + fp32_reduce_sum_of_xy_v8_3a_sve(lhs.as_ptr(), rhs.as_ptr(), lhs.len()) + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_sum_of_xy_v8_3a_sve_test() { + use rand::Rng; + const EPSILON: f32 = 0.004; + if !crate::simd::is_cpu_detected!("v8.3a") || !crate::simd::is_feature_detected!("sve") { + println!("test {} ... skipped (v8.3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v8_3a_sve(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); } } } - #[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)] - pub fn reduce_min_max_of_x(this: &[f32]) -> (f32, f32) { - let mut min = f32::INFINITY; - let mut max = f32::NEG_INFINITY; - let n = this.len(); + #[crate::simd::multiversion(@"v4", @"v3", @"v2:fma", @"v8.3a:sve", @"v8.3a")] + pub fn reduce_sum_of_xy(lhs: &[f32], rhs: &[f32]) -> f32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = 0.0f32; for i in 0..n { - min = min.min(this[i]); - max = max.max(this[i]); + xy += lhs[i] * rhs[i]; } - (min, max) + xy } } -mod reduce_sum_of_xy { +mod reduce_sum_of_d2 { + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4")] - unsafe fn reduce_sum_of_xy_v4(lhs: &[f32], rhs: &[f32]) -> f32 { + #[crate::simd::target_cpu(enable = "v4")] + fn reduce_sum_of_d2_v4(lhs: &[f32], rhs: &[f32]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; - let mut n = lhs.len(); + let mut n = lhs.len() as u32; let mut a = lhs.as_ptr(); let mut b = rhs.as_ptr(); - let mut xy = _mm512_setzero_ps(); + let mut d2 = _mm512_setzero_ps(); while n >= 16 { let x = _mm512_loadu_ps(a); let y = _mm512_loadu_ps(b); a = a.add(16); b = b.add(16); n -= 16; - xy = _mm512_fmadd_ps(x, y, xy); + let d = _mm512_sub_ps(x, y); + d2 = _mm512_fmadd_ps(d, d, d2); } if n > 0 { - let mask = _bzhi_u32(0xffff, n as u32) as u16; + let mask = _bzhi_u32(0xffff, n) as u16; let x = _mm512_maskz_loadu_ps(mask, a); let y = _mm512_maskz_loadu_ps(mask, b); - xy = _mm512_fmadd_ps(x, y, xy); + let d = _mm512_sub_ps(x, y); + d2 = _mm512_fmadd_ps(d, d, d2); } - _mm512_reduce_add_ps(xy) + _mm512_reduce_add_ps(d2) } } #[cfg(all(target_arch = "x86_64", test))] #[test] - fn reduce_sum_of_xy_v4_test() { + fn reduce_sum_of_d2_v4_test() { use rand::Rng; - const EPSILON: f32 = 0.004; - detect::init(); - if !detect::v4::detect() { + const EPSILON: f32 = 0.02; + if !crate::simd::is_cpu_detected!("v4") { println!("test {} ... skipped (v4)", module_path!()); return; } @@ -778,8 +1649,8 @@ mod reduce_sum_of_xy { for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_xy_v4(lhs, rhs) }; - let fallback = unsafe { reduce_sum_of_xy_fallback(lhs, rhs) }; + let specialized = unsafe { reduce_sum_of_d2_v4(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
@@ -788,9 +1659,10 @@ mod reduce_sum_of_xy { } } + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v3")] - unsafe fn reduce_sum_of_xy_v3(lhs: &[f32], rhs: &[f32]) -> f32 { + #[crate::simd::target_cpu(enable = "v3")] + fn reduce_sum_of_d2_v3(lhs: &[f32], rhs: &[f32]) -> f32 { use crate::simd::emulate::emulate_mm256_reduce_add_ps; assert!(lhs.len() == rhs.len()); unsafe { @@ -798,14 +1670,15 @@ mod reduce_sum_of_xy { let mut n = lhs.len(); let mut a = lhs.as_ptr(); let mut b = rhs.as_ptr(); - let mut xy = _mm256_setzero_ps(); + let mut d2 = _mm256_setzero_ps(); while n >= 8 { let x = _mm256_loadu_ps(a); let y = _mm256_loadu_ps(b); a = a.add(8); b = b.add(8); n -= 8; - xy = _mm256_fmadd_ps(x, y, xy); + let d = _mm256_sub_ps(x, y); + d2 = _mm256_fmadd_ps(d, d, d2); } if n >= 4 { let x = _mm256_zextps128_ps256(_mm_loadu_ps(a)); @@ -813,9 +1686,10 @@ mod reduce_sum_of_xy { a = a.add(4); b = b.add(4); n -= 4; - xy = _mm256_fmadd_ps(x, y, xy); + let d = _mm256_sub_ps(x, y); + d2 = _mm256_fmadd_ps(d, d, d2); } - let mut xy = emulate_mm256_reduce_add_ps(xy); + let mut d2 = emulate_mm256_reduce_add_ps(d2); // this hint is used to disable loop unrolling while std::hint::black_box(n) > 0 { let x = a.read(); @@ -823,19 +1697,19 @@ mod reduce_sum_of_xy { a = a.add(1); b = b.add(1); n -= 1; - xy += x * y; + let d = x - y; + d2 += d * d; } - xy + d2 } } #[cfg(all(target_arch = "x86_64", test))] #[test] - fn reduce_sum_of_xy_v3_test() { + fn reduce_sum_of_d2_v3_test() { use rand::Rng; - const EPSILON: f32 = 0.004; - detect::init(); - if !detect::v3::detect() { + const EPSILON: f32 = 0.02; + if !crate::simd::is_cpu_detected!("v3") { println!("test {} ... skipped (v3)", module_path!()); return; } @@ -851,8 +1725,8 @@ mod reduce_sum_of_xy { for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_xy_v3(lhs, rhs) }; - let fallback = unsafe { reduce_sum_of_xy_fallback(lhs, rhs) }; + let specialized = unsafe { reduce_sum_of_d2_v3(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
@@ -861,57 +1735,50 @@ mod reduce_sum_of_xy { } } - #[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)] - pub fn reduce_sum_of_xy(lhs: &[f32], rhs: &[f32]) -> f32 { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - let mut xy = 0.0f32; - for i in 0..n { - xy += lhs[i] * rhs[i]; - } - xy - } -} - -mod reduce_sum_of_d2 { + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4")] - unsafe fn reduce_sum_of_d2_v4(lhs: &[f32], rhs: &[f32]) -> f32 { + #[crate::simd::target_cpu(enable = "v2")] + #[target_feature(enable = "fma")] + fn reduce_sum_of_d2_v2_fma(lhs: &[f32], rhs: &[f32]) -> f32 { + use crate::simd::emulate::emulate_mm_reduce_add_ps; assert!(lhs.len() == rhs.len()); unsafe { use std::arch::x86_64::*; - let mut n = lhs.len() as u32; + let mut n = lhs.len(); let mut a = lhs.as_ptr(); let mut b = rhs.as_ptr(); - let mut d2 = _mm512_setzero_ps(); - while n >= 16 { - let x = _mm512_loadu_ps(a); - let y = _mm512_loadu_ps(b); - a = a.add(16); - b = b.add(16); - n -= 16; - let d = _mm512_sub_ps(x, y); - d2 = _mm512_fmadd_ps(d, d, d2); + let mut d2 = _mm_setzero_ps(); + while n >= 4 { + let x = _mm_loadu_ps(a); + let y = _mm_loadu_ps(b); + a = a.add(4); + b = b.add(4); + n -= 4; + let d = _mm_sub_ps(x, y); + d2 = _mm_fmadd_ps(d, d, d2); } - if n > 0 { - let mask = _bzhi_u32(0xffff, n) as u16; - let x = _mm512_maskz_loadu_ps(mask, a); - let y = _mm512_maskz_loadu_ps(mask, b); - let d = _mm512_sub_ps(x, y); - d2 = _mm512_fmadd_ps(d, d, d2); + let mut d2 = emulate_mm_reduce_add_ps(d2); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let y = b.read(); + a = a.add(1); + b = b.add(1); + n -= 1; + let d = x - y; + d2 += d * d; } - _mm512_reduce_add_ps(d2) + d2 } } #[cfg(all(target_arch = "x86_64", test))] #[test] - fn reduce_sum_of_d2_v4_test() { + fn reduce_sum_of_d2_v2_fma_test() { use rand::Rng; const EPSILON: f32 = 0.02; - detect::init(); - if !detect::v4::detect() { - println!("test {} ... skipped (v4)", module_path!()); + if !crate::simd::is_cpu_detected!("v2") || !crate::simd::is_feature_detected!("fma") { + println!("test {} ... skipped (v2:fma)", module_path!()); return; } let mut rng = rand::thread_rng(); @@ -926,8 +1793,8 @@ mod reduce_sum_of_d2 { for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_d2_v4(lhs, rhs) }; - let fallback = unsafe { reduce_sum_of_d2_fallback(lhs, rhs) }; + let specialized = unsafe { reduce_sum_of_d2_v2_fma(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
@@ -936,36 +1803,27 @@ mod reduce_sum_of_d2 { } } - #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v3")] - unsafe fn reduce_sum_of_d2_v3(lhs: &[f32], rhs: &[f32]) -> f32 { - use crate::simd::emulate::emulate_mm256_reduce_add_ps; + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + fn reduce_sum_of_d2_v8_3a(lhs: &[f32], rhs: &[f32]) -> f32 { assert!(lhs.len() == rhs.len()); unsafe { - use std::arch::x86_64::*; + use std::arch::aarch64::*; let mut n = lhs.len(); let mut a = lhs.as_ptr(); let mut b = rhs.as_ptr(); - let mut d2 = _mm256_setzero_ps(); - while n >= 8 { - let x = _mm256_loadu_ps(a); - let y = _mm256_loadu_ps(b); - a = a.add(8); - b = b.add(8); - n -= 8; - let d = _mm256_sub_ps(x, y); - d2 = _mm256_fmadd_ps(d, d, d2); - } - if n >= 4 { - let x = _mm256_zextps128_ps256(_mm_loadu_ps(a)); - let y = _mm256_zextps128_ps256(_mm_loadu_ps(b)); + let mut d2 = vdupq_n_f32(0.0); + while n >= 4 { + let x = vld1q_f32(a); + let y = vld1q_f32(b); a = a.add(4); b = b.add(4); n -= 4; - let d = _mm256_sub_ps(x, y); - d2 = _mm256_fmadd_ps(d, d, d2); + let d = vsubq_f32(x, y); + d2 = vfmaq_f32(d2, d, d); } - let mut d2 = emulate_mm256_reduce_add_ps(d2); + let mut d2 = vaddvq_f32(d2); // this hint is used to disable loop unrolling while std::hint::black_box(n) > 0 { let x = a.read(); @@ -980,14 +1838,13 @@ mod reduce_sum_of_d2 { } } - #[cfg(all(target_arch = "x86_64", test))] + #[cfg(all(target_arch = "aarch64", test))] #[test] - fn reduce_sum_of_d2_v3_test() { + fn reduce_sum_of_d2_v8_3a_test() { use rand::Rng; const EPSILON: f32 = 0.02; - detect::init(); - if !detect::v3::detect() { - println!("test {} ... skipped (v3)", module_path!()); + if !crate::simd::is_cpu_detected!("v8.3a") { + println!("test {} ... skipped (v8.3a)", module_path!()); return; } let mut rng = rand::thread_rng(); @@ -1002,8 +1859,53 @@ mod reduce_sum_of_d2 { for z in 3984..4016 { let lhs = &lhs[..z]; let rhs = &rhs[..z]; - let specialized = unsafe { reduce_sum_of_d2_v3(lhs, rhs) }; - let fallback = unsafe { reduce_sum_of_d2_fallback(lhs, rhs) }; + let specialized = unsafe { reduce_sum_of_d2_v8_3a(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + fn reduce_sum_of_d2_v8_3a_sve(lhs: &[f32], rhs: &[f32]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + extern "C" { + fn fp32_reduce_sum_of_d2_v8_3a_sve(a: *const f32, b: *const f32, n: usize) -> f32; + } + fp32_reduce_sum_of_d2_v8_3a_sve(lhs.as_ptr(), rhs.as_ptr(), lhs.len()) + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_sum_of_d2_v8_3a_sve_test() { + use rand::Rng; + const EPSILON: f32 = 0.02; + if !crate::simd::is_cpu_detected!("v8.3a") || !crate::simd::is_feature_detected!("sve") { + println!("test {} ... 
skipped (v8.3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v8_3a_sve(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -1012,7 +1914,7 @@ mod reduce_sum_of_d2 { } } - #[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)] + #[crate::simd::multiversion(@"v4", @"v3", @"v2:fma", @"v8.3a:sve", @"v8.3a")] pub fn reduce_sum_of_d2(lhs: &[f32], rhs: &[f32]) -> f32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -1028,8 +1930,8 @@ mod reduce_sum_of_d2 { mod reduce_sum_of_sparse_xy { #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4")] - unsafe fn reduce_sum_of_sparse_xy_v4(li: &[u32], lv: &[f32], ri: &[u32], rv: &[f32]) -> f32 { + #[crate::simd::target_cpu(enable = "v4")] + fn reduce_sum_of_sparse_xy_v4(li: &[u32], lv: &[f32], ri: &[u32], rv: &[f32]) -> f32 { use crate::simd::emulate::emulate_mm512_2intersect_epi32; assert_eq!(li.len(), lv.len()); assert_eq!(ri.len(), rv.len()); @@ -1077,8 +1979,7 @@ mod reduce_sum_of_sparse_xy { fn reduce_sum_of_sparse_xy_v4_test() { use rand::Rng; const EPSILON: f32 = 0.000001; - detect::init(); - if !detect::v4::detect() { + if !crate::simd::is_cpu_detected!("v4") { println!("test {} ... skipped (v4)", module_path!()); return; } @@ -1095,7 +1996,7 @@ mod reduce_sum_of_sparse_xy { .map(|_| rng.gen_range(-1.0..=1.0)) .collect::>(); let specialized = unsafe { reduce_sum_of_sparse_xy_v4(&lidx, &lval, &ridx, &rval) }; - let fallback = unsafe { reduce_sum_of_sparse_xy_fallback(&lidx, &lval, &ridx, &rval) }; + let fallback = reduce_sum_of_sparse_xy_fallback(&lidx, &lval, &ridx, &rval); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -1103,7 +2004,7 @@ mod reduce_sum_of_sparse_xy { } } - #[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)] + #[crate::simd::multiversion(@"v4", "v3", "v2", "v8.3a:sve", "v8.3a")] pub fn reduce_sum_of_sparse_xy(lidx: &[u32], lval: &[f32], ridx: &[u32], rval: &[f32]) -> f32 { use std::cmp::Ordering; assert_eq!(lidx.len(), lval.len()); @@ -1133,8 +2034,8 @@ mod reduce_sum_of_sparse_xy { mod reduce_sum_of_sparse_d2 { #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4")] - unsafe fn reduce_sum_of_sparse_d2_v4(li: &[u32], lv: &[f32], ri: &[u32], rv: &[f32]) -> f32 { + #[crate::simd::target_cpu(enable = "v4")] + fn reduce_sum_of_sparse_d2_v4(li: &[u32], lv: &[f32], ri: &[u32], rv: &[f32]) -> f32 { use crate::simd::emulate::emulate_mm512_2intersect_epi32; assert_eq!(li.len(), lv.len()); assert_eq!(ri.len(), rv.len()); @@ -1216,8 +2117,7 @@ mod reduce_sum_of_sparse_d2 { fn reduce_sum_of_sparse_d2_v4_test() { use rand::Rng; const EPSILON: f32 = 0.0004; - detect::init(); - if !detect::v4::detect() { + if !crate::simd::is_cpu_detected!("v4") { println!("test {} ... 
skipped (v4)", module_path!()); return; } @@ -1234,7 +2134,7 @@ mod reduce_sum_of_sparse_d2 { .map(|_| rng.gen_range(-1.0..=1.0)) .collect::>(); let specialized = unsafe { reduce_sum_of_sparse_d2_v4(&lidx, &lval, &ridx, &rval) }; - let fallback = unsafe { reduce_sum_of_sparse_d2_fallback(&lidx, &lval, &ridx, &rval) }; + let fallback = reduce_sum_of_sparse_d2_fallback(&lidx, &lval, &ridx, &rval); assert!( (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." @@ -1242,7 +2142,7 @@ mod reduce_sum_of_sparse_d2 { } } - #[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)] + #[crate::simd::multiversion(@"v4", "v3", "v2", "v8.3a:sve", "v8.3a")] pub fn reduce_sum_of_sparse_d2(lidx: &[u32], lval: &[f32], ridx: &[u32], rval: &[f32]) -> f32 { use std::cmp::Ordering; assert_eq!(lidx.len(), lval.len()); @@ -1277,3 +2177,119 @@ mod reduce_sum_of_sparse_d2 { d2 } } + +mod vector_add { + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_add(lhs: &[f32], rhs: &[f32]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] + rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_add_inplace { + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_add_inplace(lhs: &mut [f32], rhs: &[f32]) { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + for i in 0..n { + lhs[i] += rhs[i]; + } + } +} + +mod vector_sub { + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_sub(lhs: &[f32], rhs: &[f32]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] - rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_mul { + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_mul(lhs: &[f32], rhs: &[f32]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] * rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_mul_scalar { + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_mul_scalar(lhs: &[f32], rhs: f32) -> Vec { + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] * rhs); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_mul_scalar_inplace { + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_mul_scalar_inplace(lhs: &mut [f32], rhs: f32) { + let n = lhs.len(); + for i in 0..n { + lhs[i] *= rhs; + } + } +} + +mod vector_abs_inplace { + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_abs_inplace(this: &mut [f32]) { + let n = this.len(); + for i in 0..n { + this[i] = this[i].abs(); + } + } +} + +mod kmeans_helper { + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn kmeans_helper(this: &mut [f32], x: f32, y: f32) { + let n = this.len(); + for i in 0..n { + if i % 2 == 0 { + this[i] *= x; + } else { + this[i] *= y; + } + } + } +} diff --git a/crates/base/src/simd/fast_scan/b4.rs b/crates/base/src/simd/fast_scan/b4.rs index 454a0b1db..8476c8cde 100644 --- a/crates/base/src/simd/fast_scan/b4.rs +++ 
b/crates/base/src/simd/fast_scan/b4.rs @@ -59,9 +59,10 @@ pub fn pack(width: u32, r: [Vec; 32]) -> impl Iterator { } mod fast_scan_b4 { + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4")] - unsafe fn fast_scan_b4_v4(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + #[crate::simd::target_cpu(enable = "v4")] + fn fast_scan_b4_v4(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { // bounds checking is not enforced by compiler, so check it manually assert_eq!(codes.len(), width as usize * 16); assert_eq!(lut.len(), width as usize * 16); @@ -70,8 +71,8 @@ mod fast_scan_b4 { use std::arch::x86_64::*; #[inline] - #[detect::target_cpu(enable = "v4")] - unsafe fn combine2x2(x0x1: __m256i, y0y1: __m256i) -> __m256i { + #[crate::simd::target_cpu(enable = "v4")] + fn combine2x2(x0x1: __m256i, y0y1: __m256i) -> __m256i { unsafe { let x1y0 = _mm256_permute2f128_si256(x0x1, y0y1, 0x21); let x0y1 = _mm256_blend_epi32(x0x1, y0y1, 0xf0); @@ -80,8 +81,8 @@ mod fast_scan_b4 { } #[inline] - #[detect::target_cpu(enable = "v4")] - unsafe fn combine4x2(x0x1x2x3: __m512i, y0y1y2y3: __m512i) -> __m256i { + #[crate::simd::target_cpu(enable = "v4")] + fn combine4x2(x0x1x2x3: __m512i, y0y1y2y3: __m512i) -> __m256i { unsafe { let x0x1 = _mm512_castsi512_si256(x0x1x2x3); let x2x3 = _mm512_extracti64x4_epi64(x0x1x2x3, 1); @@ -173,8 +174,7 @@ mod fast_scan_b4 { #[cfg(target_arch = "x86_64")] #[test] fn fast_scan_b4_v4_test() { - detect::init(); - if !detect::v4::detect() { + if !crate::simd::is_cpu_detected!("v4") { println!("test {} ... skipped (v4)", module_path!()); return; } @@ -192,9 +192,10 @@ mod fast_scan_b4 { } } + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v3")] - unsafe fn fast_scan_b4_v3(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + #[crate::simd::target_cpu(enable = "v3")] + fn fast_scan_b4_v3(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { // bounds checking is not enforced by compiler, so check it manually assert_eq!(codes.len(), width as usize * 16); assert_eq!(lut.len(), width as usize * 16); @@ -203,8 +204,8 @@ mod fast_scan_b4 { use std::arch::x86_64::*; #[inline] - #[detect::target_cpu(enable = "v3")] - unsafe fn combine2x2(x0x1: __m256i, y0y1: __m256i) -> __m256i { + #[crate::simd::target_cpu(enable = "v3")] + fn combine2x2(x0x1: __m256i, y0y1: __m256i) -> __m256i { unsafe { let x1y0 = _mm256_permute2f128_si256(x0x1, y0y1, 0x21); let x0y1 = _mm256_blend_epi32(x0x1, y0y1, 0xf0); @@ -275,8 +276,7 @@ mod fast_scan_b4 { #[cfg(target_arch = "x86_64")] #[test] fn fast_scan_b4_v3_test() { - detect::init(); - if !detect::v3::detect() { + if !crate::simd::is_cpu_detected!("v3") { println!("test {} ... skipped (v3)", module_path!()); return; } @@ -295,8 +295,8 @@ mod fast_scan_b4 { } #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v2")] - unsafe fn fast_scan_b4_v2(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + #[crate::simd::target_cpu(enable = "v2")] + fn fast_scan_b4_v2(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { // bounds checking is not enforced by compiler, so check it manually assert_eq!(codes.len(), width as usize * 16); assert_eq!(lut.len(), width as usize * 16); @@ -346,8 +346,7 @@ mod fast_scan_b4 { #[cfg(target_arch = "x86_64")] #[test] fn fast_scan_b4_v2_test() { - detect::init(); - if !detect::v2::detect() { + if !crate::simd::is_cpu_detected!("v2") { println!("test {} ... 
skipped (v2)", module_path!()); return; } @@ -365,7 +364,77 @@ mod fast_scan_b4 { } } - #[detect::multiversion(v4 = import, v3 = import, v2 = import, neon, fallback = export)] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + fn fast_scan_b4_v8_3a(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + // bounds checking is not enforced by compiler, so check it manually + assert_eq!(codes.len(), width as usize * 16); + assert_eq!(lut.len(), width as usize * 16); + + unsafe { + use std::arch::aarch64::*; + + let mut accu_0 = vdupq_n_u16(0); + let mut accu_1 = vdupq_n_u16(0); + let mut accu_2 = vdupq_n_u16(0); + let mut accu_3 = vdupq_n_u16(0); + + let mut i = 0_usize; + while i < width as usize { + let c = vld1q_u8(codes.as_ptr().add(i * 16).cast()); + + let mask = vdupq_n_u8(0xf); + let clo = vandq_u8(c, mask); + let chi = vandq_u8(vshrq_n_u8(c, 4), mask); + + let lut = vld1q_u8(lut.as_ptr().add(i * 16).cast()); + let res_lo = vreinterpretq_u16_u8(vqtbl1q_u8(lut, clo)); + accu_0 = vaddq_u16(accu_0, res_lo); + accu_1 = vaddq_u16(accu_1, vshrq_n_u16(res_lo, 8)); + let res_hi = vreinterpretq_u16_u8(vqtbl1q_u8(lut, chi)); + accu_2 = vaddq_u16(accu_2, res_hi); + accu_3 = vaddq_u16(accu_3, vshrq_n_u16(res_hi, 8)); + + i += 1; + } + debug_assert_eq!(i, width as usize); + + let mut result = [0_u16; 32]; + + accu_0 = vsubq_u16(accu_0, vshlq_n_u16(accu_1, 8)); + vst1q_u16(result.as_mut_ptr().add(0).cast(), accu_0); + vst1q_u16(result.as_mut_ptr().add(8).cast(), accu_1); + + accu_2 = vsubq_u16(accu_2, vshlq_n_u16(accu_3, 8)); + vst1q_u16(result.as_mut_ptr().add(16).cast(), accu_2); + vst1q_u16(result.as_mut_ptr().add(24).cast(), accu_3); + + result + } + } + + #[cfg(target_arch = "aarch64")] + #[test] + fn fast_scan_b4_v8_3a_test() { + if !crate::simd::is_cpu_detected!("v8.3a") { + println!("test {} ... 
skipped (v8.3a)", module_path!()); + return; + } + for _ in 0..200 { + for width in 90..110 { + let codes = (0..16 * width).map(|_| rand::random()).collect::<Vec<_>>(); + let lut = (0..16 * width).map(|_| rand::random()).collect::<Vec<_>>(); + unsafe { + assert_eq!( + fast_scan_b4_v8_3a(width, &codes, &lut), + fast_scan_b4_fallback(width, &codes, &lut) + ); + } + } + } + } + + #[crate::simd::multiversion(@"v4", @"v3", @"v2", @"v8.3a")] pub fn fast_scan_b4(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { let width = width as usize; diff --git a/crates/base/src/simd/mod.rs index c94bdcc1c..406a067d6 100644 --- a/crates/base/src/simd/mod.rs +++ b/crates/base/src/simd/mod.rs @@ -60,6 +60,53 @@ pub trait ScalarLike: fn kmeans_helper(this: &mut [Self], x: f32, y: f32); } -pub fn enable() { - detect::init(); +#[allow(clippy::crate_in_macro_def)] +mod internal { + #[cfg(target_arch = "x86_64")] + base_macros::define_is_cpu_detected!("x86_64"); + + #[cfg(target_arch = "aarch64")] + base_macros::define_is_cpu_detected!("aarch64"); + + #[cfg(target_arch = "riscv64")] + base_macros::define_is_cpu_detected!("riscv64"); + + #[cfg(target_arch = "x86_64")] + #[allow(unused_imports)] + pub use is_x86_64_cpu_detected; + + #[cfg(target_arch = "aarch64")] + #[allow(unused_imports)] + pub use is_aarch64_cpu_detected; + + #[cfg(target_arch = "riscv64")] + #[allow(unused_imports)] + pub use is_riscv64_cpu_detected; } + +pub use base_macros::multiversion; +pub use base_macros::target_cpu; + +#[cfg(target_arch = "x86_64")] +#[allow(unused_imports)] +pub use std::arch::is_x86_feature_detected as is_feature_detected; + +#[cfg(target_arch = "aarch64")] +#[allow(unused_imports)] +pub use std::arch::is_aarch64_feature_detected as is_feature_detected; + +#[cfg(target_arch = "riscv64")] +#[allow(unused_imports)] +pub use std::arch::is_riscv_feature_detected as is_feature_detected; + +#[cfg(target_arch = "x86_64")] +#[allow(unused_imports)] +pub use internal::is_x86_64_cpu_detected as is_cpu_detected; + +#[cfg(target_arch = "aarch64")] +#[allow(unused_imports)] +pub use internal::is_aarch64_cpu_detected as is_cpu_detected; + +#[cfg(target_arch = "riscv64")] +#[allow(unused_imports)] +pub use internal::is_riscv64_cpu_detected as is_cpu_detected; diff --git a/crates/base/src/simd/packed_u4.rs index 24d3c522c..3273147de 100644 --- a/crates/base/src/simd/packed_u4.rs +++ b/crates/base/src/simd/packed_u4.rs @@ -1,12 +1,18 @@ -#[detect::multiversion(v4, v3, v2, fallback)] pub fn reduce_sum_of_xy(s: &[u8], t: &[u8]) -> u32 { - assert_eq!(s.len(), t.len()); - let n = s.len(); - let mut result = 0; - for i in 0..n { - let (s, t) = (s[i], t[i]); - result += ((s & 15) as u32) * ((t & 15) as u32); - result += ((s >> 4) as u32) * ((t >> 4) as u32); + reduce_sum_of_xy::reduce_sum_of_xy(s, t) +} + +mod reduce_sum_of_xy { + #[crate::simd::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_sum_of_xy(s: &[u8], t: &[u8]) -> u32 { + assert_eq!(s.len(), t.len()); + let n = s.len(); + let mut result = 0; + for i in 0..n { + let (s, t) = (s[i], t[i]); + result += ((s & 15) as u32) * ((t & 15) as u32); + result += ((s >> 4) as u32) * ((t >> 4) as u32); + } + result } - result } diff --git a/crates/base/src/simd/quantize.rs index 62170f8fe..a97d4336f 100644 --- a/crates/base/src/simd/quantize.rs +++ b/crates/base/src/simd/quantize.rs @@ -1,9 +1,10 @@ use crate::simd::*; mod mul_add_round { + #[inline] #[cfg(target_arch = "x86_64")] - 
#[detect::target_cpu(enable = "v4")] - unsafe fn mul_add_round_v4(this: &[f32], k: f32, b: f32) -> Vec<u8> { + #[crate::simd::target_cpu(enable = "v4")] + fn mul_add_round_v4(this: &[f32], k: f32, b: f32) -> Vec<u8> { let n = this.len(); let mut r = Vec::<u8>::with_capacity(n); unsafe { @@ -43,8 +44,7 @@ mod mul_add_round { #[cfg(all(target_arch = "x86_64", test))] #[test] fn mul_add_round_v4_test() { - detect::init(); - if !detect::v4::detect() { + if !crate::simd::is_cpu_detected!("v4") { println!("test {} ... skipped (v4)", module_path!()); return; } @@ -56,15 +56,16 @@ mod mul_add_round { let k = 20.0; let b = 20.0; let specialized = unsafe { mul_add_round_v4(x, k, b) }; - let fallback = unsafe { mul_add_round_fallback(x, k, b) }; + let fallback = mul_add_round_fallback(x, k, b); assert_eq!(specialized, fallback); } } } + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v3")] - unsafe fn mul_add_round_v3(this: &[f32], k: f32, b: f32) -> Vec<u8> { + #[crate::simd::target_cpu(enable = "v3")] + fn mul_add_round_v3(this: &[f32], k: f32, b: f32) -> Vec<u8> { let n = this.len(); let mut r = Vec::<u8>::with_capacity(n); unsafe { @@ -112,8 +113,7 @@ mod mul_add_round { #[cfg(all(target_arch = "x86_64", test))] #[test] fn mul_add_round_v3_test() { - detect::init(); - if !detect::v3::detect() { + if !crate::simd::is_cpu_detected!("v3") { println!("test {} ... skipped (v3)", module_path!()); return; } @@ -125,13 +125,147 @@ mod mul_add_round { let k = 20.0; let b = 20.0; let specialized = unsafe { mul_add_round_v3(x, k, b) }; - let fallback = unsafe { mul_add_round_fallback(x, k, b) }; + let fallback = mul_add_round_fallback(x, k, b); assert_eq!(specialized, fallback); } } } - #[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)] + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::simd::target_cpu(enable = "v2")] + #[target_feature(enable = "fma")] + fn mul_add_round_v2_fma(this: &[f32], k: f32, b: f32) -> Vec<u8> { + let n = this.len(); + let mut r = Vec::<u8>::with_capacity(n); + unsafe { + use std::arch::x86_64::*; + let cons = _mm_setr_epi8( + 0, 4, 8, 12, -1, -1, -1, -1, // 0..8 + -1, -1, -1, -1, -1, -1, -1, -1, // 8..15 + ); + let lk = _mm_set1_ps(k); + let lb = _mm_set1_ps(b); + let mut n = n; + let mut a = this.as_ptr(); + let mut r = r.as_mut_ptr(); + while n >= 4 { + let x = _mm_loadu_ps(a); + let v = _mm_fmadd_ps(x, lk, lb); + let v = _mm_cvtps_epi32(_mm_round_ps(v, 0x00)); + let vs = _mm_shuffle_epi8(v, cons); + let vfl = _mm_extract_epi32::<0>(vs) as u32; + r.cast::<u32>().write_unaligned(vfl); + n -= 4; + a = a.add(4); + r = r.add(4); + } + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let v = x.mul_add(k, b).round_ties_even() as u8; + r.write(v); + n -= 1; + a = a.add(1); + r = r.add(1); + } + } + unsafe { + r.set_len(n); + } + r + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn mul_add_round_v2_fma_test() { + if !crate::simd::is_cpu_detected!("v2") || !crate::simd::is_feature_detected!("fma") { + println!("test {} ... 
skipped (v2:fma)", module_path!()); + return; + } + for _ in 0..300 { + let n = 4010; + let x = (0..n).map(|_| rand::random::<_>()).collect::<Vec<_>>(); + for z in 3990..4010 { + let x = &x[..z]; + let k = 20.0; + let b = 20.0; + let specialized = unsafe { mul_add_round_v2_fma(x, k, b) }; + let fallback = mul_add_round_fallback(x, k, b); + assert_eq!(specialized, fallback); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + fn mul_add_round_v8_3a(this: &[f32], k: f32, b: f32) -> Vec<u8> { + let n = this.len(); + let mut r = Vec::<u8>::with_capacity(n); + unsafe { + use std::arch::aarch64::*; + let cons = vld1q_u8( + [ + 0, 4, 8, 12, 0xff, 0xff, 0xff, 0xff, // 0..8 + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, // 8..15 + ] + .as_ptr(), + ); + let lk = vdupq_n_f32(k); + let lb = vdupq_n_f32(b); + let mut n = n; + let mut a = this.as_ptr(); + let mut r = r.as_mut_ptr(); + while n >= 4 { + let x = vld1q_f32(a); + let v = vfmaq_f32(lb, x, lk); + let v = vcvtnq_u32_f32(v); + let vs = vqtbl1q_u8(vreinterpretq_u8_u32(v), cons); + let vfl = vgetq_lane_u32::<0>(vreinterpretq_u32_u8(vs)); + r.cast::<u32>().write_unaligned(vfl); + n -= 4; + a = a.add(4); + r = r.add(4); + } + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let v = x.mul_add(k, b).round_ties_even() as u8; + r.write(v); + n -= 1; + a = a.add(1); + r = r.add(1); + } + } + unsafe { + r.set_len(n); + } + r + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn mul_add_round_v8_3a_test() { + if !crate::simd::is_cpu_detected!("v8.3a") { + println!("test {} ... skipped (v8.3a)", module_path!()); + return; + } + for _ in 0..300 { + let n = 4010; + let x = (0..n).map(|_| rand::random::<_>()).collect::<Vec<_>>(); + for z in 3990..4010 { + let x = &x[..z]; + let k = 20.0; + let b = 20.0; + let specialized = unsafe { mul_add_round_v8_3a(x, k, b) }; + let fallback = mul_add_round_fallback(x, k, b); + assert_eq!(specialized, fallback); + } + } + } + + #[crate::simd::multiversion(@"v4", @"v3", @"v2:fma", @"v8.3a")] pub fn mul_add_round(this: &[f32], k: f32, b: f32) -> Vec<u8> { let n = this.len(); let mut r = Vec::<u8>::with_capacity(n); diff --git a/crates/base/src/simd/u8.rs index 5c838776b..928478fe5 100644 --- a/crates/base/src/simd/u8.rs +++ b/crates/base/src/simd/u8.rs @@ -1,6 +1,6 @@ use crate::simd::*; -#[detect::multiversion(v4, v3, v2, neon, fallback)] +#[multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] pub fn reduce_sum_of_xy(s: &[u8], t: &[u8]) -> u32 { assert_eq!(s.len(), t.len()); let n = s.len(); @@ -12,9 +12,10 @@ pub fn reduce_sum_of_xy(s: &[u8], t: &[u8]) -> u32 { } mod reduce_sum_of_x_as_u16 { + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4")] - unsafe fn reduce_sum_of_x_as_u16_v4(this: &[u8]) -> u16 { + #[crate::simd::target_cpu(enable = "v4")] + fn reduce_sum_of_x_as_u16_v4(this: &[u8]) -> u16 { use crate::simd::emulate::emulate_mm512_reduce_add_epi16; unsafe { use std::arch::x86_64::*; @@ -41,8 +42,7 @@ mod reduce_sum_of_x_as_u16 { #[test] fn reduce_sum_of_x_as_u16_v4_test() { use rand::Rng; - detect::init(); - if !detect::v4::detect() { + if !crate::simd::is_cpu_detected!("v4") { println!("test {} ... 
skipped (v4)", module_path!()); return; } @@ -53,15 +53,16 @@ mod reduce_sum_of_x_as_u16 { for z in 3984..4016 { let this = &this[..z]; let specialized = unsafe { reduce_sum_of_x_as_u16_v4(this) }; - let fallback = unsafe { reduce_sum_of_x_as_u16_fallback(this) }; + let fallback = reduce_sum_of_x_as_u16_fallback(this); assert_eq!(specialized, fallback); } } } + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v3")] - unsafe fn reduce_sum_of_x_as_u16_v3(this: &[u8]) -> u16 { + #[crate::simd::target_cpu(enable = "v3")] + fn reduce_sum_of_x_as_u16_v3(this: &[u8]) -> u16 { use crate::simd::emulate::emulate_mm256_reduce_add_epi16; unsafe { use std::arch::x86_64::*; @@ -91,8 +92,7 @@ mod reduce_sum_of_x_as_u16 { #[test] fn reduce_sum_of_x_as_u16_v3_test() { use rand::Rng; - detect::init(); - if !detect::v3::detect() { + if !crate::simd::is_cpu_detected!("v3") { println!("test {} ... skipped (v3)", module_path!()); return; } @@ -103,13 +103,112 @@ mod reduce_sum_of_x_as_u16 { for z in 3984..4016 { let this = &this[..z]; let specialized = unsafe { reduce_sum_of_x_as_u16_v3(this) }; - let fallback = unsafe { reduce_sum_of_x_as_u16_fallback(this) }; + let fallback = reduce_sum_of_x_as_u16_fallback(this); assert_eq!(specialized, fallback); } } } - #[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)] + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::simd::target_cpu(enable = "v2")] + fn reduce_sum_of_x_as_u16_v2(this: &[u8]) -> u16 { + use crate::simd::emulate::emulate_mm_reduce_add_epi16; + unsafe { + use std::arch::x86_64::*; + let us = _mm_set1_epi16(255); + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = _mm_setzero_si128(); + while n >= 8 { + let x = _mm_loadu_si64(a.cast()); + a = a.add(8); + n -= 8; + sum = _mm_add_epi16(_mm_and_si128(us, _mm_cvtepi8_epi16(x)), sum); + } + let mut sum = emulate_mm_reduce_add_epi16(sum) as u16; + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + sum += x as u16; + } + sum + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_x_as_u16_v2_test() { + use rand::Rng; + if !crate::simd::is_cpu_detected!("v2") { + println!("test {} ... skipped (v2)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let this = (0..n).map(|_| rng.gen_range(0..16)).collect::<Vec<_>>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_as_u16_v2(this) }; + let fallback = reduce_sum_of_x_as_u16_fallback(this); + assert_eq!(specialized, fallback); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::simd::target_cpu(enable = "v8.3a")] + fn reduce_sum_of_x_as_u16_v8_3a(this: &[u8]) -> u16 { + unsafe { + use std::arch::aarch64::*; + let us = vdupq_n_u16(255); + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = vdupq_n_u16(0); + while n >= 8 { + let x = vld1_u8(a); + a = a.add(8); + n -= 8; + sum = vaddq_u16(vandq_u16(us, vmovl_u8(x)), sum); + } + let mut sum = vaddvq_u16(sum); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + sum += x as u16; + } + sum + } + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_sum_of_x_as_u16_v8_3a_test() { + use rand::Rng; + if !crate::simd::is_cpu_detected!("v8.3a") { + println!("test {} ... 
skipped (v8.3a)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..256 { + let n = 4016; + let this = (0..n).map(|_| rng.gen_range(0..16)).collect::<Vec<_>>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_as_u16_v8_3a(this) }; + let fallback = reduce_sum_of_x_as_u16_fallback(this); + assert_eq!(specialized, fallback); + } + } + } + + #[crate::simd::multiversion(@"v4", @"v3", @"v2", @"v8.3a")] pub fn reduce_sum_of_x_as_u16(this: &[u8]) -> u16 { let n = this.len(); let mut sum = 0; @@ -126,9 +225,10 @@ pub fn reduce_sum_of_x_as_u16(vector: &[u8]) -> u16 { } mod reduce_sum_of_x_as_u32 { + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4")] - unsafe fn reduce_sum_of_x_as_u32_v4(this: &[u8]) -> u32 { + #[crate::simd::target_cpu(enable = "v4")] + fn reduce_sum_of_x_as_u32_v4(this: &[u8]) -> u32 { unsafe { use std::arch::x86_64::*; let us = _mm512_set1_epi32(255); @@ -154,8 +254,7 @@ mod reduce_sum_of_x_as_u32 { #[test] fn reduce_sum_of_x_as_u32_v4_test() { use rand::Rng; - detect::init(); - if !detect::v4::detect() { + if !crate::simd::is_cpu_detected!("v4") { println!("test {} ... skipped (v4)", module_path!()); return; } @@ -166,15 +265,16 @@ mod reduce_sum_of_x_as_u32 { for z in 3984..4016 { let this = &this[..z]; let specialized = unsafe { reduce_sum_of_x_as_u32_v4(this) }; - let fallback = unsafe { reduce_sum_of_x_as_u32_fallback(this) }; + let fallback = reduce_sum_of_x_as_u32_fallback(this); assert_eq!(specialized, fallback); } } } + #[inline] #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v3")] - unsafe fn reduce_sum_of_x_as_u32_v3(this: &[u8]) -> u32 { + #[crate::simd::target_cpu(enable = "v3")] + fn reduce_sum_of_x_as_u32_v3(this: &[u8]) -> u32 { use crate::simd::emulate::emulate_mm256_reduce_add_epi32; unsafe { use std::arch::x86_64::*; @@ -204,8 +304,7 @@ mod reduce_sum_of_x_as_u32 { #[test] fn reduce_sum_of_x_as_u16_v3_test() { use rand::Rng; - detect::init(); - if !detect::v3::detect() { + if !crate::simd::is_cpu_detected!("v3") { println!("test {} ... 
skipped (v3)", module_path!()); return; } @@ -216,13 +315,13 @@ mod reduce_sum_of_x_as_u32 { for z in 3984..4016 { let this = &this[..z]; let specialized = unsafe { reduce_sum_of_x_as_u32_v3(this) }; - let fallback = unsafe { reduce_sum_of_x_as_u32_fallback(this) }; + let fallback = reduce_sum_of_x_as_u32_fallback(this); assert_eq!(specialized, fallback); } } } - #[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)] + #[crate::simd::multiversion(@"v4", @"v3", "v2", "v8.3a:sve", "v8.3a")] pub fn reduce_sum_of_x_as_u32(this: &[u8]) -> u32 { let n = this.len(); let mut sum = 0; diff --git a/crates/base_macros/src/lib.rs index 31e726585..42496a75c 100644 --- a/crates/base_macros/src/lib.rs +++ b/crates/base_macros/src/lib.rs @@ -1,3 +1,5 @@ +mod target; + use quote::quote; use syn::{parse_macro_input, Data, DeriveInput, Fields}; @@ -75,3 +77,208 @@ pub fn alter(input: proc_macro::TokenStream) -> proc_macro::TokenStream { Data::Union(_) => panic!("union is not supported"), } } + +struct MultiversionVersion { + target: String, + import: bool, +} + +impl syn::parse::Parse for MultiversionVersion { + fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> { + let lookahead1 = input.lookahead1(); + if lookahead1.peek(syn::Token![@]) { + let _: syn::Token![@] = input.parse()?; + let target: syn::LitStr = input.parse()?; + Ok(Self { + target: target.value(), + import: true, + }) + } else { + let target: syn::LitStr = input.parse()?; + Ok(Self { + target: target.value(), + import: false, + }) + } + } +} + +struct Multiversion { + versions: syn::punctuated::Punctuated<MultiversionVersion, syn::Token![,]>, +} + +impl syn::parse::Parse for Multiversion { + fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> { + Ok(Multiversion { + versions: syn::punctuated::Punctuated::parse_terminated(input)?, + }) + } +} + +#[proc_macro_attribute] +pub fn multiversion( + attr: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let attr = syn::parse_macro_input!(attr as Multiversion); + let item_fn = syn::parse::<syn::ItemFn>(item).expect("not a function item"); + let syn::ItemFn { + attrs, + vis, + sig, + block, + } = item_fn; + let name = sig.ident.to_string(); + if sig.constness.is_some() { + panic!("const functions are not supported"); + } + if sig.asyncness.is_some() { + panic!("async functions are not supported"); + } + let generics_params = sig.generics.params.clone(); + for generic_param in generics_params.iter() { + if !matches!(generic_param, syn::GenericParam::Lifetime(_)) { + panic!("generic parameters are not supported"); + } + } + let generics_where = sig.generics.where_clause.clone(); + let inputs = sig.inputs.clone(); + let arguments = { + let mut list = vec![]; + for x in sig.inputs.iter() { + if let syn::FnArg::Typed(y) = x { + if let syn::Pat::Ident(ident) = *y.pat.clone() { + list.push(ident); + } else { + panic!("patterns on parameters are not supported") + } + } else { + panic!("receiver parameters are not supported") + } + } + list + }; + if sig.variadic.is_some() { + panic!("variadic parameters are not supported"); + } + let output = sig.output.clone(); + let mut versions = quote::quote! {}; + let mut branches = quote::quote! 
{}; + for version in attr.versions { + let target = version.target.clone(); + let name = syn::Ident::new( + &format!("{name}_{}", target.replace(":", "_").replace(".", "_")), + proc_macro2::Span::mixed_site(), + ); + let s = target.split(":").collect::<Vec<_>>(); + let target_cpu = target::TARGET_CPUS + .iter() + .find(|target_cpu| target_cpu.target_cpu == s[0]) + .expect("unknown target_cpu"); + let additional_target_features = s[1..].to_vec(); + let target_arch = target_cpu.target_arch; + let target_cpu = target_cpu.target_cpu; + if !version.import { + versions.extend(quote::quote! { + #[inline] + #[cfg(any(target_arch = #target_arch))] + #[crate::simd::target_cpu(enable = #target_cpu)] + #(#[target_feature(enable = #additional_target_features)])* + fn #name < #generics_params > (#inputs) #output #generics_where { #block } + }); + } + branches.extend(quote::quote! { + #[cfg(target_arch = #target_arch)] + if crate::simd::is_cpu_detected!(#target_cpu) #(&& crate::simd::is_feature_detected!(#additional_target_features))* { + let _multiversion_internal: unsafe fn(#inputs) #output = #name; + CACHE.store(_multiversion_internal as *mut (), core::sync::atomic::Ordering::Relaxed); + return unsafe { _multiversion_internal(#(#arguments,)*) }; + } + }); + } + let fallback_name = + syn::Ident::new(&format!("{name}_fallback"), proc_macro2::Span::mixed_site()); + quote::quote! { + #versions + fn #fallback_name < #generics_params > (#inputs) #output #generics_where { #block } + #[inline(always)] + #(#attrs)* #vis #sig { + static CACHE: core::sync::atomic::AtomicPtr<()> = core::sync::atomic::AtomicPtr::new(core::ptr::null_mut()); + let cache = CACHE.load(core::sync::atomic::Ordering::Relaxed); + if !cache.is_null() { + let f = unsafe { core::mem::transmute::<*mut (), unsafe fn(#inputs) #output>(cache as _) }; + return unsafe { f(#(#arguments,)*) }; + } + #branches + let _multiversion_internal: unsafe fn(#inputs) #output = #fallback_name; + CACHE.store(_multiversion_internal as *mut (), core::sync::atomic::Ordering::Relaxed); + unsafe { _multiversion_internal(#(#arguments,)*) } + } + } + .into() +} + +struct TargetCpu { + enable: String, +} + +impl syn::parse::Parse for TargetCpu { + fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> { + let _: syn::Ident = input.parse()?; + let _: syn::Token![=] = input.parse()?; + let enable: syn::LitStr = input.parse()?; + Ok(Self { + enable: enable.value(), + }) + } +} + +#[proc_macro_attribute] +pub fn target_cpu( + attr: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let attr = syn::parse_macro_input!(attr as TargetCpu); + let mut result = quote::quote! {}; + for s in attr.enable.split(',') { + let target_cpu = target::TARGET_CPUS + .iter() + .find(|target_cpu| target_cpu.target_cpu == s) + .expect("unknown target_cpu"); + let target_features = target_cpu.target_features; + result.extend(quote::quote!( + #(#[target_feature(enable = #target_features)])* + )); + } + result.extend(proc_macro2::TokenStream::from(item)); + result.into() +} + +#[proc_macro] +pub fn define_is_cpu_detected(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let target_arch = syn::parse_macro_input!(input as syn::LitStr).value(); + let mut arms = quote::quote! {}; + for target_cpu in target::TARGET_CPUS { + if target_cpu.target_arch != target_arch { + continue; + } + let target_features = target_cpu.target_features; + let target_cpu = target_cpu.target_cpu; + arms.extend(quote::quote! 
{ + (#target_cpu) => { + true #(&& crate::simd::is_feature_detected!(#target_features))* + }; + }); + } + let name = syn::Ident::new( + &format!("is_{target_arch}_cpu_detected"), + proc_macro2::Span::mixed_site(), + ); + quote::quote! { + #[macro_export] + macro_rules! #name { + #arms + } + } + .into() +} diff --git a/crates/base_macros/src/target.rs b/crates/base_macros/src/target.rs new file mode 100644 index 000000000..c2584013b --- /dev/null +++ b/crates/base_macros/src/target.rs @@ -0,0 +1,83 @@ +pub struct TargetCpu { + pub target_cpu: &'static str, + pub target_arch: &'static str, + pub target_features: &'static [&'static str], +} + +pub const TARGET_CPUS: &[TargetCpu] = &[ + TargetCpu { + target_cpu: "v4", + target_arch: "x86_64", + target_features: &[ + "avx", + "avx2", + "avx512bw", + "avx512cd", + "avx512dq", + "avx512f", + "avx512vl", + "bmi1", + "bmi2", + "cmpxchg16b", + "f16c", + "fma", + "fxsr", + "lzcnt", + "movbe", + "popcnt", + "sse", + "sse2", + "sse3", + "sse4.1", + "sse4.2", + "ssse3", + "xsave", + ], + }, + TargetCpu { + target_cpu: "v3", + target_arch: "x86_64", + target_features: &[ + "avx", + "avx2", + "bmi1", + "bmi2", + "cmpxchg16b", + "f16c", + "fma", + "fxsr", + "lzcnt", + "movbe", + "popcnt", + "sse", + "sse2", + "sse3", + "sse4.1", + "sse4.2", + "ssse3", + "xsave", + ], + }, + TargetCpu { + target_cpu: "v2", + target_arch: "x86_64", + target_features: &[ + "cmpxchg16b", + "fxsr", + "popcnt", + "sse", + "sse2", + "sse3", + "sse4.1", + "sse4.2", + "ssse3", + ], + }, + TargetCpu { + target_cpu: "v8.3a", + target_arch: "aarch64", + target_features: &[ + "crc", "dpb", "fcma", "jsconv", "lse", "neon", "paca", "pacg", "rcpc", "rdm", + ], + }, +]; diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index a4a323148..deeb4e394 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -32,7 +32,6 @@ fn calculate_precision(truth: &[i32], res: &[i32], top: usize) -> f32 { } fn main() { - base::simd::enable(); let args: Arguments = argh::from_env(); let path = PathBuf::from_str(&args.path).expect("failed to parse the path"); let mut log_builder = env_logger::builder(); diff --git a/crates/detect/Cargo.toml b/crates/detect/Cargo.toml deleted file mode 100644 index 5f0580873..000000000 --- a/crates/detect/Cargo.toml +++ /dev/null @@ -1,10 +0,0 @@ -[package] -name = "detect" -version.workspace = true -edition.workspace = true - -[dependencies] -detect_macros = { path = "../detect_macros" } - -[lints] -workspace = true diff --git a/crates/detect/src/lib.rs b/crates/detect/src/lib.rs deleted file mode 100644 index 995834b0d..000000000 --- a/crates/detect/src/lib.rs +++ /dev/null @@ -1,46 +0,0 @@ -/// Function multiversioning attribute macros for `pgvecto.rs`. -/// -/// ```no_run -/// #[cfg(target_arch = "x86_64")] -/// #[detect::target_cpu(enable = "v3")] -/// unsafe fn g_v3(x: &[u32]) -> u32 { -/// unimplemented!() -/// } -/// -/// #[cfg(all(target_arch = "x86_64", test))] -/// #[test] -/// fn g_v3_test() { -/// const EPSILON: F32 = F32(1e-5); -/// detect::init(); -/// if !detect::v3::detect() { -/// println!("test {} ... skipped (v3)", module_path!()); -/// return; -/// } -/// let x = vec![0u32; 400]; -/// x.fill_with(|| rand::random()); -/// let specialized = unsafe { g_v3(&x) }; -/// let fallback = unsafe { g_fallback(&x) }; -/// assert!( -/// (specialized - fallback).abs() < EPSILON, -/// "specialized = {specialized}, fallback = {fallback}." 
-/// ); -/// } -/// -/// // It generates x86_64/v3, x86_64/v2, aarch64/neon and fallback versions of this function. -/// // It takes advantage of `g_v4` as x86_64/v4 version of this function. -/// // It exposes the fallback version with the name "g_fallback". -/// #[detect::multiversion(v3 = import, v2, neon, fallback = export)] -/// fn g(x: &[u32]) -> u32 { -/// let mut result = 0_u32; -/// for v in x { -/// result = result.wrapping_add(*v); -/// } -/// result -/// } -/// ``` -pub use detect_macros::multiversion; - -/// This macros allows you to enable a set of features by target cpu names. -pub use detect_macros::target_cpu; - -detect_macros::main!(); diff --git a/crates/detect_macros/Cargo.toml b/crates/detect_macros/Cargo.toml deleted file mode 100644 index 7826a3edf..000000000 --- a/crates/detect_macros/Cargo.toml +++ /dev/null @@ -1,21 +0,0 @@ -[package] -name = "detect_macros" -version.workspace = true -edition.workspace = true - -[lib] -proc-macro = true - -[dependencies] -proc-macro2 = { version = "1.0.79", features = ["proc-macro"] } -quote = "1.0.35" -syn = { version = "2.0.53", default-features = false, features = [ - "clone-impls", - "full", - "parsing", - "printing", - "proc-macro", -] } - -[lints] -workspace = true diff --git a/crates/detect_macros/src/lib.rs b/crates/detect_macros/src/lib.rs deleted file mode 100644 index ea4ebac9f..000000000 --- a/crates/detect_macros/src/lib.rs +++ /dev/null @@ -1,320 +0,0 @@ -struct List { - target_cpu: &'static str, - target_arch: &'static str, - target_features: &'static str, -} - -const LIST: &[List] = &[ - List { - target_cpu: "v4", - target_arch: "x86_64", - target_features: - "avx,avx2,avx512bw,avx512cd,avx512dq,avx512f,avx512vl,bmi1,bmi2,cmpxchg16b,f16c,fma,fxsr,lzcnt,movbe,popcnt,sse,sse2,sse3,sse4.1,sse4.2,ssse3,xsave" - }, - List { - target_cpu: "v3", - target_arch: "x86_64", - target_features: - "avx,avx2,bmi1,bmi2,cmpxchg16b,f16c,fma,fxsr,lzcnt,movbe,popcnt,sse,sse2,sse3,sse4.1,sse4.2,ssse3,xsave" - }, - List { - target_cpu: "v2", - target_arch: "x86_64", - target_features: "cmpxchg16b,fxsr,popcnt,sse,sse2,sse3,sse4.1,sse4.2,ssse3", - }, - List { - target_cpu: "neon", - target_arch: "aarch64", - target_features: "neon", - }, - List { - target_cpu: "v4_avx512vpopcntdq", - target_arch: "x86_64", - target_features: - "avx512vpopcntdq,avx,avx2,avx512bw,avx512cd,avx512dq,avx512f,avx512vl,bmi1,bmi2,cmpxchg16b,f16c,fma,fxsr,lzcnt,movbe,popcnt,sse,sse2,sse3,sse4.1,sse4.2,ssse3,xsave", - }, - List { - target_cpu: "v4_avx512fp16", - target_arch: "x86_64", - target_features: - "avx512fp16,avx,avx2,avx512bw,avx512cd,avx512dq,avx512f,avx512vl,bmi1,bmi2,cmpxchg16b,f16c,fma,fxsr,lzcnt,movbe,popcnt,sse,sse2,sse3,sse4.1,sse4.2,ssse3,xsave", - }, - List { - target_cpu: "v4_avx512vnni", - target_arch: "x86_64", - target_features: - "avx512vnni,avx,avx2,avx512bw,avx512cd,avx512dq,avx512f,avx512vl,bmi1,bmi2,cmpxchg16b,f16c,fma,fxsr,lzcnt,movbe,popcnt,sse,sse2,sse3,sse4.1,sse4.2,ssse3,xsave", - }, -]; - -enum MultiversionPort { - Import, - Export, - Hidden, -} - -struct MultiversionVersion { - ident: String, - // Some(false) => import (specialization) - // Some(true) => export - // None => hidden - port: MultiversionPort, -} - -impl syn::parse::Parse for MultiversionVersion { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let ident: syn::Ident = input.parse()?; - let lookahead1 = input.lookahead1(); - if lookahead1.peek(syn::Token![=]) { - let _: syn::Token![=] = input.parse()?; - let p: syn::Ident = input.parse()?; - if p == 
"import" { - Ok(Self { - ident: ident.to_string(), - port: MultiversionPort::Import, - }) - } else if p == "export" { - Ok(Self { - ident: ident.to_string(), - port: MultiversionPort::Export, - }) - } else if p == "hidden" { - Ok(Self { - ident: ident.to_string(), - port: MultiversionPort::Hidden, - }) - } else { - panic!("unknown port type") - } - } else { - Ok(Self { - ident: ident.to_string(), - port: MultiversionPort::Hidden, - }) - } - } -} - -struct Multiversion { - versions: syn::punctuated::Punctuated, -} - -impl syn::parse::Parse for Multiversion { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - Ok(Multiversion { - versions: syn::punctuated::Punctuated::parse_terminated(input)?, - }) - } -} - -#[proc_macro_attribute] -pub fn multiversion( - attr: proc_macro::TokenStream, - item: proc_macro::TokenStream, -) -> proc_macro::TokenStream { - let attr = syn::parse_macro_input!(attr as Multiversion); - let item_fn = syn::parse::(item).expect("not a function item"); - let syn::ItemFn { - attrs, - vis, - sig, - block, - } = item_fn; - let name = sig.ident.to_string(); - if sig.constness.is_some() { - panic!("const functions are not supported"); - } - if sig.asyncness.is_some() { - panic!("async functions are not supported"); - } - let generics_params = sig.generics.params.clone(); - for generic_param in generics_params.iter() { - if !matches!(generic_param, syn::GenericParam::Lifetime(_)) { - panic!("generic parameters are not supported"); - } - } - let generics_where = sig.generics.where_clause.clone(); - let inputs = sig.inputs.clone(); - let arguments = { - let mut list = vec![]; - for x in sig.inputs.iter() { - if let syn::FnArg::Typed(y) = x { - if let syn::Pat::Ident(ident) = *y.pat.clone() { - list.push(ident); - } else { - panic!("patterns on parameters are not supported") - } - } else { - panic!("receiver parameters are not supported") - } - } - list - }; - if sig.variadic.is_some() { - panic!("variadic parameters are not supported"); - } - let output = sig.output.clone(); - let mut versions_export = quote::quote! {}; - let mut versions_hidden = quote::quote! {}; - let mut branches = quote::quote! {}; - let mut fallback = false; - for version in attr.versions { - let ident = version.ident.clone(); - let name = syn::Ident::new(&format!("{name}_{ident}"), proc_macro2::Span::mixed_site()); - let port; - let branch; - if fallback { - panic!("fallback version is set"); - } else if ident == "fallback" { - fallback = true; - port = quote::quote! { - unsafe fn #name < #generics_params > (#inputs) #output #generics_where { #block } - }; - branch = quote::quote! { - { - let _multiversion_internal: unsafe fn(#inputs) #output = #name; - CACHE.store(_multiversion_internal as *mut (), core::sync::atomic::Ordering::Relaxed); - unsafe { _multiversion_internal(#(#arguments,)*) } - } - }; - } else { - let target_cpu = ident.clone(); - let t = syn::Ident::new(&target_cpu, proc_macro2::Span::mixed_site()); - let list = LIST - .iter() - .find(|list| list.target_cpu == target_cpu) - .expect("unknown target_cpu"); - let target_arch = list.target_arch; - let target_features = list.target_features; - port = quote::quote! { - #[cfg(any(target_arch = #target_arch))] - #[target_feature(enable = #target_features)] - unsafe fn #name < #generics_params > (#inputs) #output #generics_where { #block } - }; - branch = quote::quote! 
{ - #[cfg(target_arch = #target_arch)] - if detect::#t::detect() { - let _multiversion_internal: unsafe fn(#inputs) #output = #name; - CACHE.store(_multiversion_internal as *mut (), core::sync::atomic::Ordering::Relaxed); - return unsafe { _multiversion_internal(#(#arguments,)*) }; - } - }; - } - match version.port { - MultiversionPort::Import => (), - MultiversionPort::Export => versions_export.extend(port), - MultiversionPort::Hidden => versions_hidden.extend(port), - } - branches.extend(branch); - } - if !fallback { - panic!("fallback version is not set"); - } - quote::quote! { - #versions_export - #[inline(always)] - #(#attrs)* #vis #sig { - #versions_hidden - static CACHE: core::sync::atomic::AtomicPtr<()> = core::sync::atomic::AtomicPtr::new(core::ptr::null_mut()); - let cache = CACHE.load(core::sync::atomic::Ordering::Relaxed); - if !cache.is_null() { - let f = unsafe { core::mem::transmute::<*mut (), unsafe fn(#inputs) #output>(cache as _) }; - return unsafe { f(#(#arguments,)*) }; - } - #branches - } - } - .into() -} - -struct TargetCpu { - enable: String, -} - -impl syn::parse::Parse for TargetCpu { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let _: syn::Ident = input.parse()?; - let _: syn::Token![=] = input.parse()?; - let enable: syn::LitStr = input.parse()?; - Ok(Self { - enable: enable.value(), - }) - } -} - -#[proc_macro_attribute] -pub fn target_cpu( - attr: proc_macro::TokenStream, - item: proc_macro::TokenStream, -) -> proc_macro::TokenStream { - let attr = syn::parse_macro_input!(attr as TargetCpu); - let mut result = quote::quote! {}; - for cpu in attr.enable.split(',') { - let list = LIST - .iter() - .find(|list| list.target_cpu == cpu) - .expect("unknown target_cpu"); - let target_features = list.target_features; - result.extend(quote::quote!(#[target_feature(enable = #target_features)])); - } - result.extend(proc_macro2::TokenStream::from(item)); - result.into() -} - -#[proc_macro] -pub fn main(_: proc_macro::TokenStream) -> proc_macro::TokenStream { - let mut modules = quote::quote! {}; - let mut init = quote::quote! {}; - for x in LIST { - let ident = syn::Ident::new(x.target_cpu, proc_macro2::Span::mixed_site()); - let target_cpu = x.target_cpu; - let list = LIST - .iter() - .find(|list| list.target_cpu == target_cpu) - .expect("unknown target_cpu"); - let target_arch = list.target_arch; - let target_features = list.target_features.split(',').collect::>(); - modules.extend(quote::quote! { - #[cfg(target_arch = #target_arch)] - pub mod #ident { - use std::sync::atomic::{AtomicBool, Ordering}; - - static ATOMIC: AtomicBool = AtomicBool::new(false); - - #[cfg(target_arch = "x86_64")] - pub fn test() -> bool { - true #(&& std::arch::is_x86_feature_detected!(#target_features))* - } - - #[cfg(target_arch = "aarch64")] - pub fn test() -> bool { - true #(&& std::arch::is_aarch64_feature_detected!(#target_features))* - } - - #[cfg(target_arch = "riscv64")] - pub fn test() -> bool { - true #(&& std::arch::is_riscv_feature_detected!(#target_features))* - } - - pub(crate) fn init() { - ATOMIC.store(test(), Ordering::Relaxed); - } - - pub fn detect() -> bool { - ATOMIC.load(Ordering::Relaxed) - } - } - }); - init.extend(quote::quote! { - #[cfg(target_arch = #target_arch)] - self::#ident::init(); - }); - } - quote::quote! 
{ - #modules - pub fn init() { - #init - } - } - .into() -} diff --git a/crates/pyvectors/src/lib.rs b/crates/pyvectors/src/lib.rs index 793aa25ea..7c0d865ad 100644 --- a/crates/pyvectors/src/lib.rs +++ b/crates/pyvectors/src/lib.rs @@ -17,7 +17,6 @@ use with_labels::WithLabels; #[pymodule] fn vectors(m: &Bound<'_, PyModule>) -> PyResult<()> { - base::simd::enable(); m.add_class::()?; Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index a7d3068ec..82c198485 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,6 @@ unsafe extern "C" fn _PG_init() { bad_init(); } unsafe { - base::simd::enable(); gucs::init(); index::init(); ipc::init();