[XPU][PHI Kernels] Support bf16 for fused_rope #60064
Conversation
Your PR has been submitted successfully. Thank you for contributing to this open-source project!
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather");
using XPUTFp16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
using XPUTBf16 = typename XPUTypeTrait<phi::dtype::bfloat16>::Type;
if (std::is_same<XPUT, XPUTBf16>::value) {
At the time this PR was submitted, the bf16 gather kernel had not been released yet, so FP16 is used here temporarily; a follow-up PR will switch it back to BF16.
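For context, a minimal sketch of what that temporary fallback could look like. The XDNN gather signature (ctx, x, index, y, xshape, index_len, axis) and the variable names in, index, out, shape, index_len, and axis are assumptions here, not the PR's exact code. The trick is safe because gather is pure data movement: it copies 16-bit elements without doing arithmetic, so a bf16 buffer can be reinterpreted as fp16 without changing any bits.

// Sketch: route bf16 gather through the fp16 path until a native bf16
// gather is available. gather performs no arithmetic, so reinterpreting
// the 16-bit element type does not alter the copied values.
using XPUTFp16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
using XPUTBf16 = typename XPUTypeTrait<phi::dtype::bfloat16>::Type;
int ret = 0;
if (std::is_same<XPUT, XPUTBf16>::value) {
  ret = xpu::gather<XPUTFp16, int64_t>(
      dev_ctx.x_context(),
      reinterpret_cast<const XPUTFp16*>(in),   // bf16 bits, viewed as fp16
      index,
      reinterpret_cast<XPUTFp16*>(out),
      shape,
      index_len,
      axis);
} else {
  ret = xpu::gather<XPUT, int64_t>(
      dev_ctx.x_context(), in, index, out, shape, index_len, axis);
}
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather");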
self.shape = [2, 8, 2, 16]

def test_api(self):
    q_fp32 = paddle.rand(self.shape, dtype="float32")
If the BF16 unit test here used the non-fused implementation, it would hit other kernels that do not yet support bf16, so for now it is compared against the fp32 fused operator.
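A minimal sketch of that comparison strategy. The fused API name (paddle.incubate.nn.functional.fused_rotary_position_embedding), the shape, and the tolerances are assumptions for illustration, not the PR's exact test values.

# Sketch: compare the bf16 fused op against the fp32 fused op, because the
# non-fused reference path would route through kernels without bf16 support.
import numpy as np
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

shape = [2, 8, 2, 16]  # [batch, seq_len, num_heads, head_dim] (illustrative)

# Reference: the fused op in fp32.
q_fp32 = paddle.rand(shape, dtype="float32")
out_fp32, _, _ = fused_rotary_position_embedding(q_fp32)

# Under test: the same fused op in bf16, compared with a loose tolerance
# to absorb bf16's reduced (8-bit) mantissa precision.
q_bf16 = q_fp32.astype("bfloat16")
out_bf16, _, _ = fused_rotary_position_embedding(q_bf16)

np.testing.assert_allclose(
    out_fp32.numpy(),
    out_bf16.astype("float32").numpy(),
    rtol=1e-2,  # illustrative tolerance, not the PR's exact value
    atol=1e-2,
)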
LGTM.
@XiaociZhang please help take a look at the distributed unit tests.
@@ -75,4 +75,5 @@ PD_REGISTER_KERNEL(c_embedding,
                   ALL_LAYOUT,
                   phi::CEmbeddingKernel,
                   float,
                   phi::dtype::float16) {}
The latest embedding API code does not seem to support bf16 yet; could you double-check?
Done
@@ -86,4 +86,5 @@ PD_REGISTER_KERNEL(c_embedding_grad,
                   ALL_LAYOUT,
                   phi::CEmbeddingGradKernel,
                   float,
                   phi::dtype::float16) {}
Same as above.
Done
@@ -32,6 +32,7 @@ def test_identity(self):
    "float64",
    "int32",
    "int64",
    "bfloat16",
The unit-test changes look fine. Have you verified them locally with ctest? I am not sure whether these operators are on the CI blacklist.
LGTM
* support fused_rope fp16
* bug fix
* bug fix
PR types
New features
PR changes
OPs
Description