
[Hackathon No.52] Add float16 data type support for the Paddle dist operator #50915

Merged · 25 commits · Apr 28, 2023

Conversation

@jinyouzhi (Contributor) commented Feb 25, 2023

PR types

New features

PR changes

OPs

Description

Task: #50658 (comment)

Chinese documentation: PaddlePaddle/docs#5740

OP Performance:

| OP | shape | p | fp32 | fp16 |
| --- | --- | --- | --- | --- |
| dist_forward | [1000,1000] | 2 | 0.05938422923185387 | 0.050949320501210746 |
| dist_backward | [1000,1000] | 2 | 0.11299653929107042 | 0.08545359786675902 |
| dist_forward | [1000,1000] | inf | 0.0472954341343471 | 0.044048319057542445 |
| dist_backward | [1000,1000] | inf | 0.09125203502421475 | 0.08334651285288286 |
| dist_forward | [1000,1000] | 0 | 0.04742607778432417 | 0.045565683014538824 |
| dist_backward | [1000,1000] | 0 | 0.08797159000318877 | 0.08200139415507414 |

[measured on AI Studio]

@paddle-bot commented Feb 25, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@paddle-bot added the labels contributor (External developers) and status: proposed on Feb 25, 2023
@CLAassistant commented Feb 25, 2023

CLA assistant check: all committers have signed the CLA.

Resolved review threads (outdated) on:
- paddle/phi/kernels/cpu/p_norm_kernel.cc
- paddle/phi/kernels/dist_grad_kernel.cc
- paddle/phi/kernels/dist_kernel.cc
- python/paddle/tensor/linalg.py
@jinyouzhi force-pushed the fp16/dist branch 2 times, most recently from b1ee7df to 9edcfa2 (March 2, 2023 20:33)
@jinyouzhi marked this pull request as draft (March 7, 2023 17:20)
@jinyouzhi force-pushed the fp16/dist branch 2 times, most recently from e60a97c to 8ff0da6 (March 13, 2023 07:23)
@jinyouzhi force-pushed the fp16/dist branch 3 times, most recently from 1ec1d67 to f7d2bef (March 19, 2023 12:36)
@jinyouzhi marked this pull request as ready for review (March 19, 2023 16:02)
@jinyouzhi (Author) commented

@zhangting2020 Most of the CI checks have passed. Could you take another look?

@jinyouzhi (Author) commented

@zhangting2020 Added the performance data.

Review thread on paddle/phi/kernels/funcs/math_cuda_utils.h (outdated), on this snippet:
val_ret += __shfl_xor(val_ret, mask, warpSize);
#endif
return static_cast<phi::dtype::float16>(val_ret);
}
Contributor (reviewer) commented:

In the earlier revision of this change, didn't calling the original implementation compile and run fine? The main difference is that the else branch converts to fp32? This kind of case does not need to be handled at the operator level.

Contributor Author (jinyouzhi) replied:

The original implementation failed to compile because what gets passed in is phi::dtype::float16, while the CUDA intrinsic's half-precision parameter type is __half, so I added a template specialization to handle fp16. I'm not sure how fp16 and CUDA's __half are bridged inside the framework.
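For context, a minimal sketch of the kind of specialization being described (illustrative only, not the PR's actual code): the generic template passes the value straight to the CUDA shuffle intrinsic, which has no overload for phi::dtype::float16, so an fp16 specialization promotes to float and casts back. WarpReduceSum and lane_mask are placeholder names.

```cpp
#include "paddle/phi/common/float16.h"  // phi::dtype::float16 (path assumed)

// Generic warp-level sum reduction; works for float/double, which the CUDA
// shuffle intrinsics accept directly.
template <typename T>
__device__ __forceinline__ T WarpReduceSum(T val, unsigned lane_mask) {
  for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(lane_mask, val, offset, warpSize);
  }
  return val;  // with the XOR pattern, every lane ends up with the warp sum
}

// fp16 specialization: __shfl_xor_sync cannot take phi::dtype::float16, so
// promote to float, reduce, and cast back to float16 at the end.
template <>
__device__ __forceinline__ phi::dtype::float16 WarpReduceSum(
    phi::dtype::float16 val, unsigned lane_mask) {
  float val_f = static_cast<float>(val);
  for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
    val_f += __shfl_xor_sync(lane_mask, val_f, offset, warpSize);
  }
  return static_cast<phi::dtype::float16>(val_f);
}
```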

Contributor (reviewer) replied:

This change works for getting fp16 to run, but there are further considerations from the perspective of the operator's numerical precision:

  • For operations like reduce sum, fp16 loses precision easily, so the computation should be kept in fp32 while the inputs and outputs stay fp16. You should trace from ReduceSumWithSubtract, which calls this function: by the time execution reaches this function, the input type should no longer be float16, so adding float16 support here may not be necessary.

If you want to make this function more general so that it also supports float16, you can slightly modify the original interface:

  • You can find the CudaShuffleDownSync interface in paddle/phi/backends/gpu/cuda/cuda_device_function.h; that is the more recommended way to write it (see the sketch after this list).
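A rough sketch of what this suggestion could look like; the namespace and exact signature of CudaShuffleDownSync are assumed from its header location named above, and WarpReduceSum is again a placeholder, not the final PR code.

```cpp
#include "paddle/phi/backends/gpu/cuda/cuda_device_function.h"

// Route the shuffle through the framework wrapper instead of hand-writing a
// per-dtype specialization; the wrapper is expected to handle fp16 internally.
template <typename T>
__device__ __forceinline__ T WarpReduceSum(T val, unsigned lane_mask) {
  for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
    // Shuffle-down reduction: after the loop, lane 0 holds the warp's sum.
    val += phi::backends::gpu::CudaShuffleDownSync(lane_mask, val, offset);
  }
  return val;
}

// Per the precision note above, a reduce-sum over fp16 data would still be
// driven with a float accumulator and cast back only at the boundary, e.g.:
//   float partial = WarpReduceSum(static_cast<float>(x), 0xFFFFFFFFu);
```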

Contributor Author (jinyouzhi) replied:

  • Regarding precision: the inf/-inf path is a max/min reduction, so there is no precision concern and it keeps the fp16 implementation; for -inf < p < inf, the computation is done in float32 (the code may be a bit naive, comments welcome 😝). A toy sketch of this split follows below.
  • math_cuda_utils: changed to call CudaShuffleDownSync for fp16 compatibility.
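To make the precision split above concrete, here is a self-contained toy CUDA kernel (single-thread, launched as <<<1,1>>>, with made-up names; not the PR's dist kernel): the p = inf path is a plain max reduction over fp16 inputs, while finite p > 0 accumulates |x - y|^p in float and casts back to half only once at the end.

```cpp
#include <cuda_fp16.h>
#include <math.h>

// Toy illustration of the strategy: max/min is order-preserving, so fp16
// needs no widening for p = inf; finite p uses an fp32 accumulator to limit
// rounding and overflow error before the final cast back to half.
__global__ void dist_sketch(const __half* x, const __half* y, int n, float p,
                            __half* out) {
  if (isinf(p)) {
    float best = 0.f;  // could equally be tracked as __half
    for (int i = 0; i < n; ++i) {
      best = fmaxf(best, fabsf(__half2float(x[i]) - __half2float(y[i])));
    }
    *out = __float2half(best);
  } else {
    float acc = 0.f;  // fp32 accumulation of the p-th powers
    for (int i = 0; i < n; ++i) {
      acc += powf(fabsf(__half2float(x[i]) - __half2float(y[i])), p);
    }
    *out = __float2half(powf(acc, 1.f / p));
  }
}
```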

Contributor Author (jinyouzhi) replied:

@zhangting2020 The changes are done. Is the current approach acceptable?

Review thread on paddle/phi/kernels/funcs/math_cuda_utils.h (outdated, resolved).
@jinyouzhi force-pushed the fp16/dist branch 4 times, most recently from f888d7c to 4922ad6 (March 27, 2023 17:51)
@jinyouzhi (Author) commented

[screenshot of the CI op-benchmark history omitted]

Judging from the history, there is still a performance regression.

The regression is not in dist; it is mainly interp_trilinear and histogram that regressed. Neither of those depends on dist, so it may be a side effect of the changes to paddle/phi/kernels/funcs/math_cuda_utils.h.
I will rebase once and re-test, and also prepare a version that reverts the math_cuda_utils changes.

Labels: contributor (External developers)
Projects: none yet
6 participants