Skip to content

Commit

Permalink
chore: add (S)SSE & NEON & SVE impl
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <usamoi@outlook.com>
  • Loading branch information
usamoi committed Dec 13, 2024
1 parent 954f256 commit 8fda7a5
Show file tree
Hide file tree
Showing 24 changed files with 3,177 additions and 1,144 deletions.
30 changes: 4 additions & 26 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Rust

on:
push:
branches: ["main", "vchord"]
branches: ["vchord"]
paths:
- ".cargo/**"
- ".github/**"
Expand All @@ -16,7 +16,7 @@ on:
- "vectors.control"
- "vendor/**"
pull_request:
branches: ["main", "vchord"]
branches: ["vchord"]
paths:
- ".cargo/**"
- ".github/**"
Expand Down Expand Up @@ -60,6 +60,8 @@ jobs:
sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*'
sudo apt-get purge -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*'
sudo apt-get update
curl --proto '=https' --tlsv1.2 -sSf https://apt.llvm.org/llvm.sh | sudo bash -s -- 18
sudo update-alternatives --install /usr/bin/clang clang $(which clang-18) 255
sudo apt-get install -y build-essential crossbuild-essential-arm64
sudo apt-get install -y qemu-user-static
touch ~/.cargo/config.toml
Expand All @@ -81,33 +83,9 @@ jobs:
- name: Clippy
run: |
cargo clippy --workspace --exclude pgvectors --exclude pyvectors --target $ARCH-unknown-linux-gnu
export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg14_${ARCH}_debian/pg_config/pg_config
export PGRX_TARGET_INFO_PATH_PG14=$(pwd)/vendor/pg14_${ARCH}_debian/pgrx_binding
cargo clippy --package pgvectors --features pg14 --no-deps --target $ARCH-unknown-linux-gnu
export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg15_${ARCH}_debian/pg_config/pg_config
export PGRX_TARGET_INFO_PATH_PG15=$(pwd)/vendor/pg15_${ARCH}_debian/pgrx_binding
cargo clippy --package pgvectors --features pg15 --no-deps --target $ARCH-unknown-linux-gnu
export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg16_${ARCH}_debian/pg_config/pg_config
export PGRX_TARGET_INFO_PATH_PG16=$(pwd)/vendor/pg16_${ARCH}_debian/pgrx_binding
cargo clippy --package pgvectors --features pg16 --no-deps --target $ARCH-unknown-linux-gnu
export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg17_${ARCH}_debian/pg_config/pg_config
export PGRX_TARGET_INFO_PATH_PG17=$(pwd)/vendor/pg17_${ARCH}_debian/pgrx_binding
cargo clippy --package pgvectors --features pg17 --no-deps --target $ARCH-unknown-linux-gnu
- name: Build
run: |
cargo build --workspace --exclude pgvectors --exclude pyvectors --target $ARCH-unknown-linux-gnu
export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg14_${ARCH}_debian/pg_config/pg_config
export PGRX_TARGET_INFO_PATH_PG14=$(pwd)/vendor/pg14_${ARCH}_debian/pgrx_binding
cargo build --package pgvectors --lib --features pg14 --target $ARCH-unknown-linux-gnu
export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg15_${ARCH}_debian/pg_config/pg_config
export PGRX_TARGET_INFO_PATH_PG15=$(pwd)/vendor/pg15_${ARCH}_debian/pgrx_binding
cargo build --package pgvectors --lib --features pg15 --target $ARCH-unknown-linux-gnu
export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg16_${ARCH}_debian/pg_config/pg_config
export PGRX_TARGET_INFO_PATH_PG16=$(pwd)/vendor/pg16_${ARCH}_debian/pgrx_binding
cargo build --package pgvectors --lib --features pg16 --target $ARCH-unknown-linux-gnu
export PGRX_PG_CONFIG_PATH=$(pwd)/vendor/pg17_${ARCH}_debian/pg_config/pg_config
export PGRX_TARGET_INFO_PATH_PG17=$(pwd)/vendor/pg17_${ARCH}_debian/pgrx_binding
cargo build --package pgvectors --lib --features pg17 --target $ARCH-unknown-linux-gnu
- name: Test
run: |
cargo test --workspace --exclude pgvectors --exclude pyvectors --no-fail-fast --target $ARCH-unknown-linux-gnu
Expand Down
25 changes: 6 additions & 19 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion crates/base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ toml.workspace = true
validator.workspace = true

base_macros = { path = "../base_macros" }
detect = { path = "../detect" }

[build-dependencies]
cc = "1.2.3"

[lints]
workspace = true
8 changes: 8 additions & 0 deletions crates/base/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
fn main() {
println!("cargo::rerun-if-changed=cshim.c");
cc::Build::new()
.compiler("clang")
.file("cshim.c")
.opt_level(3)
.compile("base_cshim");
}
219 changes: 219 additions & 0 deletions crates/base/cshim.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
#include <stddef.h>
#include <stdint.h>

#ifdef __aarch64__

#include <arm_neon.h>
#include <arm_sve.h>

__attribute__((target("v8.3a,fp16"))) float
fp16_reduce_sum_of_d2_v8_3a_fp16_unroll(__fp16 *__restrict a,
__fp16 *__restrict b, size_t n) {
float16x8_t d2_0 = vdupq_n_f16(0.0);
float16x8_t d2_1 = vdupq_n_f16(0.0);
float16x8_t d2_2 = vdupq_n_f16(0.0);
float16x8_t d2_3 = vdupq_n_f16(0.0);
while (n >= 32) {
float16x8_t x_0 = vld1q_f16(a + 0);
float16x8_t x_1 = vld1q_f16(a + 8);
float16x8_t x_2 = vld1q_f16(a + 16);
float16x8_t x_3 = vld1q_f16(a + 24);
float16x8_t y_0 = vld1q_f16(b + 0);
float16x8_t y_1 = vld1q_f16(b + 8);
float16x8_t y_2 = vld1q_f16(b + 16);
float16x8_t y_3 = vld1q_f16(b + 24);
a += 32;
b += 32;
n -= 32;
float16x8_t d_0 = vsubq_f16(x_0, y_0);
float16x8_t d_1 = vsubq_f16(x_1, y_1);
float16x8_t d_2 = vsubq_f16(x_2, y_2);
float16x8_t d_3 = vsubq_f16(x_3, y_3);
d2_0 = vfmaq_f16(d2_0, d_0, d_0);
d2_1 = vfmaq_f16(d2_1, d_1, d_1);
d2_2 = vfmaq_f16(d2_2, d_2, d_2);
d2_3 = vfmaq_f16(d2_3, d_3, d_3);
}
if (n > 0) {
__fp16 A[32] = {};
__fp16 B[32] = {};
for (size_t i = 0; i < n; i += 1) {
A[i] = a[i];
B[i] = b[i];
}
float16x8_t x_0 = vld1q_f16(A + 0);
float16x8_t x_1 = vld1q_f16(A + 8);
float16x8_t x_2 = vld1q_f16(A + 16);
float16x8_t x_3 = vld1q_f16(A + 24);
float16x8_t y_0 = vld1q_f16(B + 0);
float16x8_t y_1 = vld1q_f16(B + 8);
float16x8_t y_2 = vld1q_f16(B + 16);
float16x8_t y_3 = vld1q_f16(B + 24);
float16x8_t d_0 = vsubq_f16(x_0, y_0);
float16x8_t d_1 = vsubq_f16(x_1, y_1);
float16x8_t d_2 = vsubq_f16(x_2, y_2);
float16x8_t d_3 = vsubq_f16(x_3, y_3);
d2_0 = vfmaq_f16(d2_0, d_0, d_0);
d2_1 = vfmaq_f16(d2_1, d_1, d_1);
d2_2 = vfmaq_f16(d2_2, d_2, d_2);
d2_3 = vfmaq_f16(d2_3, d_3, d_3);
}
float16x8_t d2 = vaddq_f16(vaddq_f16(d2_0, d2_1), vaddq_f16(d2_2, d2_3));
return vgetq_lane_f16(d2, 0) + vgetq_lane_f16(d2, 1) + vgetq_lane_f16(d2, 2) +
vgetq_lane_f16(d2, 3) + vgetq_lane_f16(d2, 4) + vgetq_lane_f16(d2, 5) +
vgetq_lane_f16(d2, 6) + vgetq_lane_f16(d2, 7);
}

__attribute__((target("v8.3a,fp16"))) float
fp16_reduce_sum_of_xy_v8_3a_fp16_unroll(__fp16 *__restrict a,
__fp16 *__restrict b, size_t n) {
float16x8_t xy_0 = vdupq_n_f16(0.0);
float16x8_t xy_1 = vdupq_n_f16(0.0);
float16x8_t xy_2 = vdupq_n_f16(0.0);
float16x8_t xy_3 = vdupq_n_f16(0.0);
while (n >= 32) {
float16x8_t x_0 = vld1q_f16(a + 0);
float16x8_t x_1 = vld1q_f16(a + 8);
float16x8_t x_2 = vld1q_f16(a + 16);
float16x8_t x_3 = vld1q_f16(a + 24);
float16x8_t y_0 = vld1q_f16(b + 0);
float16x8_t y_1 = vld1q_f16(b + 8);
float16x8_t y_2 = vld1q_f16(b + 16);
float16x8_t y_3 = vld1q_f16(b + 24);
a += 32;
b += 32;
n -= 32;
xy_0 = vfmaq_f16(xy_0, x_0, y_0);
xy_1 = vfmaq_f16(xy_1, x_1, y_1);
xy_2 = vfmaq_f16(xy_2, x_2, y_2);
xy_3 = vfmaq_f16(xy_3, x_3, y_3);
}
if (n > 0) {
__fp16 A[32] = {};
__fp16 B[32] = {};
for (size_t i = 0; i < n; i += 1) {
A[i] = a[i];
B[i] = b[i];
}
float16x8_t x_0 = vld1q_f16(A + 0);
float16x8_t x_1 = vld1q_f16(A + 8);
float16x8_t x_2 = vld1q_f16(A + 16);
float16x8_t x_3 = vld1q_f16(A + 24);
float16x8_t y_0 = vld1q_f16(B + 0);
float16x8_t y_1 = vld1q_f16(B + 8);
float16x8_t y_2 = vld1q_f16(B + 16);
float16x8_t y_3 = vld1q_f16(B + 24);
xy_0 = vfmaq_f16(xy_0, x_0, y_0);
xy_1 = vfmaq_f16(xy_1, x_1, y_1);
xy_2 = vfmaq_f16(xy_2, x_2, y_2);
xy_3 = vfmaq_f16(xy_3, x_3, y_3);
}
float16x8_t xy = vaddq_f16(vaddq_f16(xy_0, xy_1), vaddq_f16(xy_2, xy_3));
return vgetq_lane_f16(xy, 0) + vgetq_lane_f16(xy, 1) + vgetq_lane_f16(xy, 2) +
vgetq_lane_f16(xy, 3) + vgetq_lane_f16(xy, 4) + vgetq_lane_f16(xy, 5) +
vgetq_lane_f16(xy, 6) + vgetq_lane_f16(xy, 7);
}

__attribute__((target("v8.3a,sve"))) float
fp16_reduce_sum_of_d2_v8_3a_sve(__fp16 *__restrict a, __fp16 *__restrict b,
size_t n) {
svfloat16_t d2 = svdup_f16(0.0);
for (size_t i = 0; i < n; i += svcnth()) {
svbool_t mask = svwhilelt_b16(i, n);
svfloat16_t x = svld1_f16(mask, a + i);
svfloat16_t y = svld1_f16(mask, b + i);
svfloat16_t d = svsub_f16_x(mask, x, y);
d2 = svmla_f16_x(mask, d2, d, d);
}
return svaddv_f16(svptrue_b16(), d2);
}

__attribute__((target("v8.3a,sve"))) float
fp16_reduce_sum_of_xy_v8_3a_sve(__fp16 *__restrict a, __fp16 *__restrict b,
size_t n) {
svfloat16_t xy = svdup_f16(0.0);
for (size_t i = 0; i < n; i += svcnth()) {
svbool_t mask = svwhilelt_b16(i, n);
svfloat16_t x = svld1_f16(mask, a + i);
svfloat16_t y = svld1_f16(mask, b + i);
xy = svmla_f16_x(mask, xy, x, y);
}
return svaddv_f16(svptrue_b16(), xy);
}

__attribute__((target("v8.3a,sve"))) float
fp32_reduce_sum_of_x_v8_3a_sve(float *__restrict this, size_t n) {
svfloat32_t sum = svdup_f32(0.0);
for (size_t i = 0; i < n; i += svcntw()) {
svbool_t mask = svwhilelt_b32(i, n);
svfloat32_t x = svld1_f32(mask, this + i);
sum = svadd_f32_x(mask, sum, x);
}
return svaddv_f32(svptrue_b32(), sum);
}

__attribute__((target("v8.3a,sve"))) float
fp32_reduce_sum_of_abs_x_v8_3a_sve(float *__restrict this, size_t n) {
svfloat32_t sum = svdup_f32(0.0);
for (size_t i = 0; i < n; i += svcntw()) {
svbool_t mask = svwhilelt_b32(i, n);
svfloat32_t x = svld1_f32(mask, this + i);
sum = svadd_f32_x(mask, sum, svabs_f32_x(mask, x));
}
return svaddv_f32(svptrue_b32(), sum);
}

__attribute__((target("v8.3a,sve"))) float
fp32_reduce_sum_of_x2_v8_3a_sve(float *__restrict this, size_t n) {
svfloat32_t sum = svdup_f32(0.0);
for (size_t i = 0; i < n; i += svcntw()) {
svbool_t mask = svwhilelt_b32(i, n);
svfloat32_t x = svld1_f32(mask, this + i);
sum = svmla_f32_x(mask, sum, x, x);
}
return svaddv_f32(svptrue_b32(), sum);
}

__attribute__((target("v8.3a,sve"))) void
fp32_reduce_min_max_of_x_v8_3a_sve(float *__restrict this, size_t n,
float *out_min, float *out_max) {
svfloat32_t min = svdup_f32(1.0 / 0.0);
svfloat32_t max = svdup_f32(-1.0 / 0.0);
for (size_t i = 0; i < n; i += svcntw()) {
svbool_t mask = svwhilelt_b32(i, n);
svfloat32_t x = svld1_f32(mask, this + i);
min = svmin_f32_x(mask, min, x);
max = svmax_f32_x(mask, max, x);
}
*out_min = svminv_f32(svptrue_b32(), min);
*out_max = svmaxv_f32(svptrue_b32(), max);
}

__attribute__((target("v8.3a,sve"))) float
fp32_reduce_sum_of_xy_v8_3a_sve(float *__restrict lhs, float *__restrict rhs,
size_t n) {
svfloat32_t sum = svdup_f32(0.0);
for (size_t i = 0; i < n; i += svcntw()) {
svbool_t mask = svwhilelt_b32(i, n);
svfloat32_t x = svld1_f32(mask, lhs + i);
svfloat32_t y = svld1_f32(mask, rhs + i);
sum = svmla_f32_x(mask, sum, x, y);
}
return svaddv_f32(svptrue_b32(), sum);
}

__attribute__((target("v8.3a,sve"))) float
fp32_reduce_sum_of_d2_v8_3a_sve(float *__restrict lhs, float *__restrict rhs,
size_t n) {
svfloat32_t sum = svdup_f32(0.0);
for (size_t i = 0; i < n; i += svcntw()) {
svbool_t mask = svwhilelt_b32(i, n);
svfloat32_t x = svld1_f32(mask, lhs + i);
svfloat32_t y = svld1_f32(mask, rhs + i);
svfloat32_t d = svsub_f32_x(mask, x, y);
sum = svmla_f32_x(mask, sum, d, d);
}
return svaddv_f32(svptrue_b32(), sum);
}

#endif
1 change: 1 addition & 0 deletions crates/base/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![feature(target_feature_11)]
#![feature(avx512_target_feature)]
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))]
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512_f16))]
Expand Down
Loading

0 comments on commit 8fda7a5

Please sign in to comment.