
[XPU][PHI Kernels] Support bf16 for fused_rope #60064

Merged
3 commits merged into PaddlePaddle:develop on Dec 19, 2023

Conversation

@lj970926 (Contributor) commented Dec 15, 2023

PR types

New features

PR changes

OPs

Description

  1. Support bf16 for fused_rope, c_embedding, c_split, c_identity, and shape (the registration pattern is sketched below).
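
For context, enabling a dtype for an XPU PHI kernel is normally a one-line addition to its PD_REGISTER_KERNEL list, as the diff hunks later in this thread show. A minimal sketch of the pattern, using a placeholder kernel name rather than this PR's actual files:

```cpp
// Illustrative sketch only: "my_op" and phi::MyOpKernel are placeholders, not
// kernels touched by this PR; the real hunks follow the same pattern.
PD_REGISTER_KERNEL(my_op,
                   XPU,                      // backend
                   ALL_LAYOUT,               // data layout
                   phi::MyOpKernel,          // kernel function template
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}  // the new dtype enabled by this PR
```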

paddle-bot commented Dec 15, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

paddle-bot added the contributor (External developers) label on Dec 15, 2023
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather");
using XPUTFp16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
using XPUTBf16 = typename XPUTypeTrait<phi::dtype::bfloat16>::Type;
if (std::is_same<XPUT, XPUTBf16>::value) {
@lj970926 (Contributor, Author):

At the time this PR was submitted, the bf16 gather kernel had not been released yet, so FP16 is used here temporarily; a follow-up PR will switch this back to BF16.
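
In other words, the kernel dispatches on the XPU element type and temporarily routes bf16 through the fp16 path until a native bf16 gather is available. A minimal standalone sketch of that dispatch pattern, with stand-in types and print statements instead of the real XDNN gather call:

```cpp
#include <iostream>
#include <type_traits>

// Stand-ins for XPUTypeTrait<phi::dtype::float16>::Type and
// XPUTypeTrait<phi::dtype::bfloat16>::Type; not the real Paddle types.
struct XPUTFp16 {};
struct XPUTBf16 {};

template <typename XPUT>
void GatherDispatch() {
  if (std::is_same<XPUT, XPUTBf16>::value) {
    // bf16 gather not yet available: fall back to the fp16 code path for now.
    std::cout << "gather via fp16 fallback\n";
  } else {
    // Every other dtype uses its native gather path.
    std::cout << "gather via native dtype\n";
  }
}

int main() {
  GatherDispatch<XPUTBf16>();  // takes the temporary fp16 fallback branch
  GatherDispatch<XPUTFp16>();  // takes the native branch
}
```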

self.shape = [2, 8, 2, 16]

def test_api(self):
    q_fp32 = paddle.rand(self.shape, dtype="float32")
@lj970926 (Contributor, Author):

If the BF16 unit test used the non-fused implementation as its reference, it would hit other kernels that do not yet support bf16, so for now it is compared against the fp32 fused op.

@houj04 (Contributor) left a comment:

LGTM.
@XiaociZhang please help take a look at the distributed-related unit tests.

@@ -75,4 +75,5 @@ PD_REGISTER_KERNEL(c_embedding,
ALL_LAYOUT,
phi::CEmbeddingKernel,
float,
phi::dtype::float16) {}
Reviewer (Contributor):

The latest embedding API code does not seem to support bf16 yet; please double-check.

@lj970926 (Contributor, Author):

Done

@@ -86,4 +86,5 @@ PD_REGISTER_KERNEL(c_embedding_grad,
ALL_LAYOUT,
phi::CEmbeddingGradKernel,
float,
phi::dtype::float16) {}
Reviewer (Contributor):

Same as above.

@lj970926 (Contributor, Author):

Done

@@ -32,6 +32,7 @@ def test_identity(self):
"float64",
"int32",
"int64",
"bfloat16",
Reviewer (Contributor):

The unit test changes look fine. Have you verified them locally with ctest? I am not sure whether these ops are on the CI blacklist.

@XiaociZhang (Contributor):

LGTM

@houj04 merged commit 2a424a8 into PaddlePaddle:develop on Dec 19, 2023
29 checks passed
HermitSun pushed a commit to HermitSun/Paddle that referenced this pull request Dec 21, 2023
* support fused_rope fp16

* bug fix

* bug fix
@houj04 added the XPU label on Sep 4, 2024
Labels: contributor (External developers), XPU
4 participants