
feat: add x64 simd for sparse vector operations #417

Merged · 18 commits · Mar 13, 2024

Conversation

silver-ymz (Member)

Signed-off-by: Mingzhuo Yin <yinmingzhuo@gmail.com>
usamoi (Collaborator) commented Mar 12, 2024

@VoVAllen

Please grant @silver-ymz permission on tensorchord/stdarch.

VoVAllen (Member)

done

VoVAllen (Member)

Also, I think we can remove the native vp2intersect path, since it's slower according to https://arxiv.org/abs/2112.06342

silver-ymz (Member, Author)

> And I think we can remove native vp2intersect since it's slower according to https://arxiv.org/abs/2112.06342

Actually, the paper says the emulation is faster only when a single output mask is needed. In our case, we need both output masks. https://ar5iv.labs.arxiv.org/html/2112.06342#:~:text=When%20used%20to,first%20output%20mask
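For context, the intersection that vp2intersect (or its shuffle-based emulation) computes over 16-element blocks is the same one a scalar sparse dot product performs with two pointers over sorted index arrays. A minimal scalar sketch of that baseline, with a hypothetical function name (not the PR's actual code):

```rust
/// Dot product of two sparse vectors, each given as a sorted array of
/// indices plus a matching array of values. Scalar two-pointer
/// intersection; the SIMD path discussed here computes the same sum
/// blockwise over 16 indices at a time.
fn sparse_dot(ia: &[u32], va: &[f32], ib: &[u32], vb: &[f32]) -> f32 {
    let (mut i, mut j) = (0usize, 0usize);
    let mut sum = 0.0f32;
    while i < ia.len() && j < ib.len() {
        match ia[i].cmp(&ib[j]) {
            // Advance whichever side has the smaller index.
            std::cmp::Ordering::Less => i += 1,
            std::cmp::Ordering::Greater => j += 1,
            // Indices match: accumulate the product and advance both.
            std::cmp::Ordering::Equal => {
                sum += va[i] * vb[j];
                i += 1;
                j += 1;
            }
        }
    }
    sum
}

fn main() {
    let (ia, va) = ([1u32, 3, 5, 7], [1.0f32, 2.0, 3.0, 4.0]);
    let (ib, vb) = ([2u32, 3, 7, 9], [10.0f32, 20.0, 30.0, 40.0]);
    // Indices 3 and 7 intersect: 2.0*20.0 + 4.0*30.0 = 160
    println!("{}", sparse_dot(&ia, &va, &ib, &vb));
}
```

The SIMD version replaces the inner comparison with an all-pairs equality test between two 16-lane blocks (the vpcmpeqd/korw sequences visible in the assembly below), then uses the resulting masks to compress and multiply the matching values.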

VoVAllen (Member) commented Mar 12, 2024

The gap looks small; I'd prefer to remove the native vp2intersect path for simplicity. Intel appears to have deprecated this instruction on its latest CPUs, so it's probably hard to find hardware that supports it.

silver-ymz (Member, Author)

> Prefer to remove the native vp2intersect for simplicity.

done

@VoVAllen VoVAllen requested a review from usamoi March 12, 2024 11:33
// Instructions. arXiv preprint arXiv:2112.06342.
#[inline]
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512bw,avx512f")]
Collaborator

Is it inlined? It seems that this function suffers from the compiler bug too.

Member Author

cargo asm shows that it's inlined.

❯ cargo asm -p base 692
    Finished `release` profile [optimized + debuginfo] target(s) in 0.05s

.section .text.base::global::svecf32::dot_v4,"ax",@progbits
	.p2align	4, 0x90
	.type	base::global::svecf32::dot_v4,@function
base::global::svecf32::dot_v4:
	.cfi_startproc
	push rbp
	.cfi_def_cfa_offset 16
	push r15
	.cfi_def_cfa_offset 24
	push r14
	.cfi_def_cfa_offset 32
	push r13
	.cfi_def_cfa_offset 40
	push r12
	.cfi_def_cfa_offset 48
	push rbx
	.cfi_def_cfa_offset 56
	sub rsp, 56
	.cfi_def_cfa_offset 112
	.cfi_offset rbx, -56
	.cfi_offset r12, -48
	.cfi_offset r13, -40
	.cfi_offset r14, -32
	.cfi_offset r15, -24
	.cfi_offset rbp, -16
	mov rax, qword ptr [rdi + 8]
	mov rcx, rax
	shr rcx, 32
	jne .LBB97_25
	mov rcx, qword ptr [rsi + 8]
	mov rdx, rcx
	shr rdx, 32
	jne .LBB97_25
	xor r9d, r9d
	mov r11d, 4294967280
	mov rdx, qword ptr [rdi]
	mov rdi, qword ptr [rdi + 16]
	mov qword ptr [rsp + 48], rdi
	mov r8, qword ptr [rsi]
	mov rsi, qword ptr [rsi + 16]
	mov qword ptr [rsp + 40], rsi
	vxorps xmm0, xmm0, xmm0
	and r11, rax
	je .LBB97_3
	mov ebx, ecx
	mov r10d, 0
	and ebx, -16
	je .LBB97_4
	vxorps xmm1, xmm1, xmm1
	xor r14d, r14d
	xor r15d, r15d
	mov qword ptr [rsp + 32], rax
	mov qword ptr [rsp + 24], rcx
	.p2align	4, 0x90
.LBB97_9:
	lea r9, [r15 + 16]
	lea rsi, [r14 + 16]
	mov r10d, dword ptr [rdx + 4*r15 + 60]
	xor edi, edi
	cmp r10d, dword ptr [r8 + 4*r14 + 60]
	setne dil
	mov eax, 255
	cmovb edi, eax
	test dil, dil
	je .LBB97_12
	movzx edi, dil
	mov r10, r14
	cmp edi, 1
	jne .LBB97_13
	mov r9, r15
.LBB97_12:
	mov r10, rsi
.LBB97_13:
	vmovdqu64 zmm4, zmmword ptr [rdx + 4*r15]
	vmovdqu64 zmm0, zmmword ptr [r8 + 4*r14]
	vshufi64x2 zmm5, zmm4, zmm4, 57
	vshufi64x2 zmm3, zmm4, zmm4, 78
	vshufi64x2 zmm2, zmm4, zmm4, 147
	vpshufd zmm6, zmm0, 57
	vpshufd zmm7, zmm0, 78
	vpcmpeqd k0, zmm4, zmm0
	kmovd r12d, k0
	vpcmpeqd k1, zmm4, zmm6
	kmovw word ptr [rsp + 16], k1
	vpcmpeqd k0, zmm4, zmm7
	kmovw word ptr [rsp + 20], k0
	vpcmpeqd k3, zmm5, zmm0
	vpcmpeqd k2, zmm5, zmm6
	korw k4, k2, k3
	kmovw word ptr [rsp + 10], k4
	kmovd r13d, k3
	vpcmpeqd k4, zmm5, zmm7
	vpcmpeqd k6, zmm3, zmm6
	vpcmpeqd k3, zmm2, zmm6
	kmovw word ptr [rsp + 14], k3
	korw k2, k1, k2
	korw k3, k6, k3
	korw k2, k2, k3
	vpcmpeqd k3, zmm3, zmm7
	kmovd ebp, k2
	vpcmpeqd k1, zmm2, zmm7
	kmovw word ptr [rsp + 12], k1
	korw k2, k0, k4
	korw k7, k3, k1
	korw k2, k2, k7
	kmovd edi, k2
	vpshufd zmm6, zmm0, 147
	vpcmpeqd k2, zmm4, zmm6
	kmovw word ptr [rsp + 18], k2
	vpcmpeqd k0, zmm5, zmm6
	vpcmpeqd k1, zmm3, zmm6
	vpcmpeqd k7, zmm2, zmm6
	korw k2, k2, k0
	korw k5, k1, k7
	korw k2, k2, k5
	kmovd esi, k2
	korw k0, k4, k0
	kmovw k2, word ptr [rsp + 10]
	korw k0, k2, k0
	kmovd ecx, k0
	korw k0, k3, k1
	vpcmpeqd k3, zmm3, zmm0
	korw k1, k6, k3
	korw k0, k1, k0
	kmovd eax, k0
	kmovw k0, word ptr [rsp + 12]
	korw k0, k0, k7
	vpcmpeqd k4, zmm2, zmm0
	kmovw k1, word ptr [rsp + 14]
	korw k1, k1, k4
	korw k0, k1, k0
	rol cx, 4
	rol ax, 8
	or eax, ecx
	kmovd ecx, k0
	rol cx, 12
	or ecx, r12d
	or ecx, eax
	kmovd k0, ecx
	kmovw k1, word ptr [rsp + 16]
	korw k1, k0, k1
	mov eax, ebp
	and eax, -2185
	shr ebp, 3
	and ebp, 4369
	lea ecx, [4*rdi]
	and ecx, -13108
	shr edi, 2
	and edi, 13107
	or edi, ecx
	lea ecx, [8*rsi]
	and ecx, -30584
	shr esi
	and esi, 30583
	or esi, ecx
	lea eax, [rbp + 2*rax]
	or eax, edi
	or esi, r13d
	or esi, eax
	or esi, r12d
	kmovd k0, esi
	korw k0, k0, k3
	kmovw k2, word ptr [rsp + 20]
	kmovw k3, word ptr [rsp + 18]
	korw k2, k2, k3
	korw k1, k1, k2
	mov rax, qword ptr [rsp + 48]
	vmovdqu64 zmm0, zmmword ptr [rax + 4*r15]
	vcompressps zmm2 {k1} {z}, zmm0
	korw k1, k0, k4
	mov rax, qword ptr [rsp + 40]
	vmovdqu64 zmm0, zmmword ptr [rax + 4*r14]
	vcompressps zmm0 {k1} {z}, zmm0
	vfmadd213ps zmm0, zmm2, zmm1
	cmp r9, r11
	jae .LBB97_14
	mov r14, r10
	mov r15, r9
	vmovaps zmm1, zmm0
	cmp r10, rbx
	mov rax, qword ptr [rsp + 32]
	mov rcx, qword ptr [rsp + 24]
	jb .LBB97_9
	jmp .LBB97_4
.LBB97_14:
	mov rax, qword ptr [rsp + 32]
	mov rcx, qword ptr [rsp + 24]
.LBB97_4:
	cmp r9, rax
	jae .LBB97_6
.LBB97_5:
	cmp r10, rcx
	jae .LBB97_6
	mov r11d, 16
	mov ebx, 65535
	.p2align	4, 0x90
.LBB97_16:
	mov rsi, rax
	sub rsi, r9
	cmp rsi, 16
	cmovae rsi, r11
	mov rdi, rcx
	sub rdi, r10
	cmp rdi, 16
	cmovae rdi, r11
	bzhi ebp, ebx, esi
	kmovd k1, ebp
	lea r14, [rdx + 4*r9]
	#APP

	vmovdqu32 zmm4 {k1} {z}, zmmword ptr [r14]

	#NO_APP
	bzhi ebp, ebx, edi
	kmovd k2, ebp
	lea r14, [r8 + 4*r10]
	#APP

	vmovdqu32 zmm3 {k2} {z}, zmmword ptr [r14]

	#NO_APP
	mov r14, qword ptr [rsp + 48]
	lea r14, [r14 + 4*r9]
	#APP

	vmovups zmm1 {k1} {z}, zmmword ptr [r14]

	#NO_APP
	mov r14, qword ptr [rsp + 40]
	lea r14, [r14 + 4*r10]
	#APP

	vmovups zmm2 {k2} {z}, zmmword ptr [r14]

	#NO_APP
	add rsi, r9
	add rdi, r10
	mov ebp, dword ptr [rdx + 4*rsi - 4]
	xor esi, esi
	cmp ebp, dword ptr [r8 + 4*rdi - 4]
	setne sil
	mov edi, 255
	cmovb esi, edi
	test sil, sil
	je .LBB97_19
	movzx esi, sil
	cmp esi, 1
	je .LBB97_20
	add r9, 16
	jmp .LBB97_21
	.p2align	4, 0x90
.LBB97_19:
	add r9, 16
.LBB97_20:
	add r10, 16
.LBB97_21:
	vshufi64x2 zmm7, zmm4, zmm4, 57
	vshufi64x2 zmm6, zmm4, zmm4, 78
	vshufi64x2 zmm5, zmm4, zmm4, 147
	vshufps zmm8, zmm3, zmm3, 57
	vshufpd zmm9, zmm3, zmm3, 85
	vpcmpeqd k0, zmm4, zmm3
	kmovd r14d, k0
	vpcmpeqd k1, zmm8, zmm4
	kmovw word ptr [rsp + 16], k1
	vpcmpeqd k0, zmm9, zmm4
	kmovw word ptr [rsp + 20], k0
	vpcmpeqd k3, zmm7, zmm3
	vpcmpeqd k2, zmm7, zmm8
	korw k4, k2, k3
	kmovw word ptr [rsp + 10], k4
	kmovd r15d, k3
	vpcmpeqd k4, zmm7, zmm9
	vpcmpeqd k6, zmm6, zmm8
	vpcmpeqd k3, zmm5, zmm8
	kmovw word ptr [rsp + 14], k3
	korw k2, k1, k2
	korw k3, k6, k3
	korw k2, k2, k3
	vpcmpeqd k3, zmm6, zmm9
	kmovd r12d, k2
	vpcmpeqd k1, zmm5, zmm9
	kmovw word ptr [rsp + 12], k1
	korw k2, k0, k4
	korw k7, k3, k1
	korw k2, k2, k7
	kmovd r13d, k2
	vpshufd zmm8, zmm3, 147
	vpcmpeqd k2, zmm8, zmm4
	kmovw word ptr [rsp + 18], k2
	vpcmpeqd k0, zmm7, zmm8
	vpcmpeqd k1, zmm6, zmm8
	vpcmpeqd k7, zmm5, zmm8
	korw k2, k2, k0
	korw k5, k1, k7
	korw k2, k2, k5
	kmovd ebp, k2
	korw k0, k4, k0
	kmovw k2, word ptr [rsp + 10]
	korw k0, k2, k0
	kmovd esi, k0
	korw k0, k3, k1
	vpcmpeqd k3, zmm6, zmm3
	korw k1, k6, k3
	korw k0, k1, k0
	kmovd edi, k0
	kmovw k0, word ptr [rsp + 12]
	korw k0, k0, k7
	vpcmpeqd k4, zmm5, zmm3
	kmovw k1, word ptr [rsp + 14]
	korw k1, k1, k4
	korw k0, k1, k0
	rol si, 4
	rol di, 8
	or edi, esi
	kmovd esi, k0
	rol si, 12
	or esi, r14d
	or esi, edi
	kmovd k0, esi
	kmovw k1, word ptr [rsp + 16]
	korw k1, k0, k1
	mov esi, r12d
	and esi, -2185
	shr r12d, 3
	and r12d, 4369
	lea edi, [4*r13]
	and edi, -13108
	shr r13d, 2
	and r13d, 13107
	or r13d, edi
	lea edi, [8*rbp]
	and edi, -30584
	shr ebp
	and ebp, 30583
	or ebp, edi
	lea esi, [r12 + 2*rsi]
	or esi, r13d
	or ebp, r15d
	or ebp, esi
	or ebp, r14d
	kmovd k0, ebp
	korw k0, k0, k3
	kmovw k2, word ptr [rsp + 20]
	kmovw k3, word ptr [rsp + 18]
	korw k2, k2, k3
	korw k1, k1, k2
	vcompressps zmm3 {k1} {z}, zmm1
	korw k1, k0, k4
	vcompressps zmm1 {k1} {z}, zmm2
	vfmadd213ps zmm1, zmm3, zmm0
	cmp r9, rax
	jae .LBB97_23
	vmovaps zmm0, zmm1
	cmp r10, rcx
	jb .LBB97_16
	jmp .LBB97_23
.LBB97_3:
	xor r10d, r10d
	cmp r9, rax
	jb .LBB97_5
.LBB97_6:
	vmovaps zmm1, zmm0
.LBB97_23:
	vextractf64x4 ymm0, zmm1, 1
	vaddps zmm0, zmm1, zmm0
	vextractf128 xmm1, ymm0, 1
	vaddps xmm0, xmm0, xmm1
	vshufpd xmm1, xmm0, xmm0, 1
	vaddps xmm0, xmm0, xmm1
	vmovshdup xmm1, xmm0
	vaddss xmm0, xmm0, xmm1
	vxorps xmm1, xmm1, xmm1
	vaddss xmm0, xmm0, xmm1
	add rsp, 56
	.cfi_def_cfa_offset 56
	pop rbx
	.cfi_def_cfa_offset 48
	pop r12
	.cfi_def_cfa_offset 40
	pop r13
	.cfi_def_cfa_offset 32
	pop r14
	.cfi_def_cfa_offset 24
	pop r15
	.cfi_def_cfa_offset 16
	pop rbp
	.cfi_def_cfa_offset 8
	vzeroupper
	ret
.LBB97_25:
	.cfi_def_cfa_offset 112
	lea rdi, [rip + .Lanon.77122362e8253e66b52972cd6fb83991.3]
	lea rcx, [rip + .Lanon.77122362e8253e66b52972cd6fb83991.4]
	lea r8, [rip + .Lanon.77122362e8253e66b52972cd6fb83991.148]
	lea rdx, [rsp + 23]
	mov esi, 43
	call qword ptr [rip + core::result::unwrap_failed@GOTPCREL]

usamoi (Collaborator) previously approved these changes Mar 13, 2024 and left a comment:

Please resolve conflicts.

Merged via the queue into tensorchord:main with commit bd7a0a6 Mar 13, 2024
8 checks passed
@silver-ymz silver-ymz deleted the feat/sparse-simd branch March 13, 2024 08:00
usamoi (Collaborator) commented Mar 13, 2024

Tests failed after being merged.

silver-ymz (Member, Author)

thread 'vector::veci8::tests_2::test_cos_i8' panicked at crates/base/src/vector/veci8.rs:552:9:
assertion failed: (result.0 - result_expected).abs() / result_expected < 0.25
stack backtrace:
   0: rust_begin_unwind
             at /rustc/516b6162a2ea8e66678c09e8243ebd83e4b8eeea/library/std/src/panicking.rs:645:5
   1: core::panicking::panic_fmt
             at /rustc/516b6162a2ea8e66678c09e8243ebd83e4b8eeea/library/core/src/panicking.rs:72:14
   2: core::panicking::panic
             at /rustc/516b6162a2ea8e66678c09e8243ebd83e4b8eeea/library/core/src/panicking.rs:144:5
   3: base::vector::veci8::tests_2::test_cos_i8
             at ./src/vector/veci8.rs:552:9
   4: base::vector::veci8::tests_2::test_cos_i8::{{closure}}
             at ./src/vector/veci8.rs:531:21
   5: core::ops::function::FnOnce::call_once
             at /rustc/516b6162a2ea8e66678c09e8243ebd83e4b8eeea/library/core/src/ops/function.rs:250:5
   6: core::ops::function::FnOnce::call_once
             at /rustc/516b6162a2ea8e66678c09e8243ebd83e4b8eeea/library/core/src/ops/function.rs:250:5
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.
test vector::veci8::tests_2::test_cos_i8 ... FAILED

It seems related to veci8. cc @my-vegetable-has-exploded

usamoi (Collaborator) commented Mar 13, 2024

This test fails intermittently. I thought it was just because the threshold was too strict, so I relaxed the constant from 0.01 to 0.25, but that doesn't seem to have helped...

my-vegetable-has-exploded (Contributor) commented Mar 13, 2024

I ran this test many times, and I think it fails when the cosine result is close to zero, where the relative-error check becomes arbitrarily strict.
Do you think it would be a good idea to change it to the code below? @usamoi @silver-ymz

            assert!((result.0 - result_expected).abs() / result_expected < 0.25 || (result.0 - result_expected).abs() < 0.001);
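The proposed check accepts either a small relative error or, when the expected value is near zero, a small absolute error. A standalone sketch of that predicate (hypothetical helper name; this sketch uses `expected.abs()` in the relative branch to keep it sign-safe, a small deviation from the assertion above):

```rust
/// Accept `got` if it is within 25% relative error of `expected`,
/// or within 0.001 absolute error. The absolute branch covers
/// expected values near zero, where relative error blows up.
fn close_enough(got: f32, expected: f32) -> bool {
    let abs_err = (got - expected).abs();
    abs_err / expected.abs() < 0.25 || abs_err < 0.001
}

fn main() {
    // Near-zero expected value: relative error is 3.0, but the
    // absolute error of 0.0003 passes the second branch.
    println!("{}", close_enough(0.0004, 0.0001));
    // Larger values fall back to the relative check.
    println!("{}", close_enough(1.1, 1.0));  // 10% relative error: passes
    println!("{}", close_enough(2.0, 1.0)); // 100% relative error: fails
}
```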

usamoi (Collaborator) commented Mar 13, 2024

@my-vegetable-has-exploded Yes, that looks good.
