
【Hackathon No.32】Optimize the GPU computation performance of the expand_as forward & backward op for Paddle #52700

Merged · 5 commits · Apr 17, 2023

Conversation


@Timber-Ye Timber-Ye commented Apr 9, 2023

PR types

Performance optimization

PR changes

OPs

Describe

The current GPU implementation of the expand_as forward and backward operators in Paddle is composed from Eigen primitives and has no dedicated GPU kernel, so its performance is relatively poor. This PR implements high-performance GPU kernels to optimize the computation performance of the expand_as op on GPU.

  • Development environment
  1. Device: Tesla V100-32G
  2. CUDA 11.2, cuDNN v8.1.1
  • Optimization approach

[Operator performance optimization design document]

Since the forward pass of expand_as is essentially a broadcast, and the backward pass is essentially a sum reduction, the operator is optimized directly with Paddle's internal BroadcastKernel and ReduceKernel.
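This equivalence can be illustrated with a small NumPy sketch (an illustration only, not Paddle's actual kernel code; it assumes src and dst already have the same rank, and the shapes are taken from case 1 in the table below):

```python
import numpy as np

# Forward: expand_as replicates src along every size-1 axis until it
# matches dst's shape -- exactly NumPy's broadcasting rule.
src = np.arange(5, dtype=np.float32).reshape(5, 1, 1)   # src_shape [5, 1, 1]
dst_shape = (5, 128, 128)
out = np.broadcast_to(src, dst_shape)                   # forward pass

# Backward: each src element was copied to many output positions, so its
# gradient is the sum of the upstream gradient over all broadcast axes.
grad_out = np.ones(dst_shape, dtype=np.float32)
reduce_axes = tuple(
    i for i, s in enumerate(src.shape) if s == 1 and dst_shape[i] > 1
)
grad_src = grad_out.sum(axis=reduce_axes, keepdims=True)

print(out.shape)          # (5, 128, 128)
print(grad_src.shape)     # (5, 1, 1)
print(grad_src[0, 0, 0])  # 16384.0 == 128 * 128 copies per element
```

The forward step maps to a single BroadcastKernel launch and the backward step to a single ReduceKernel launch, instead of the Eigen expression trees used before.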

  • Optimization results

Performance comparison between Paddle after optimization (Optimized) and before optimization (Baseline):

| Case | Data type | src_shape | dst_shape | Baseline (ms) | Optimized (ms) | Diff |
| --- | --- | --- | --- | --- | --- | --- |
| 0 | float16 | [1785, 1] | [1785, 128] | 0.1971 | 0.1166 | 40.835% faster |
| 1 | float16 | [5, 1, 1] | [5, 128, 128] | 3.0909 | 0.1269 | 95.895% faster |
| 2 | float16 | [32, 807, 1] | [32, 807, 807] | 1.3884 | 0.3940 | 71.620% faster |
| 3 | float32 | [1785, 1] | [1785, 128] | 0.2244 | 0.1150 | 48.760% faster |
| 4 | float32 | [5, 1, 1] | [5, 128, 128] | 3.6155 | 0.1179 | 96.738% faster |
| 5 | float32 | [32, 807, 1] | [32, 807, 807] | 1.4826 | 0.6428 | 56.646% faster |
| 6 | float64 | [32, 1, 1] | [32, 807, 807] | 288.7776 | 1.2293 | 99.570% faster |
| 7 | float64 | [1, 1, 64, 5] | [64, 128, 64, 5] | 3.1326 | 0.2746 | 91.645% faster |
| 8 | float64 | [5, 1, 1, 1, 1] | [5, 1, 713, 1, 889] | 240.8861 | 0.2960 | 99.877% faster |

Across the nine cases above, the optimized kernels are consistently faster, and the more elements the tensor is expanded to, the larger the gain; on case 8 the optimized operator's runtime drops to roughly 1/814 of the baseline.
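The Diff column follows directly from the two timing columns; as a quick sanity check on case 8 (numbers copied from the table above):

```python
# Case 8 timings from the table above, in milliseconds.
baseline, optimized = 240.8861, 0.2960

pct_faster = (baseline - optimized) / baseline * 100  # relative time saved
ratio = baseline / optimized                          # baseline vs. optimized

print(f"{pct_faster:.3f}% faster")            # 99.877% faster
print(f"~1/{ratio:.0f} of baseline runtime")  # ~1/814 of baseline runtime
```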

Original PR:

Timber-Ye and others added 3 commits April 8, 2023 22:31
Co-authored-by: Timber-Ye <ye_hanqiao@163.com>
Co-authored-by: BrianQian1999 <brianqianhitsz@gmail.com>
@paddle-bot

paddle-bot bot commented Apr 9, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added contributor External developers status: proposed labels Apr 9, 2023
@BrianQian1999
Contributor

Hi, we need TPM approval to pass the static test, so to complete the CI. @jzhang533 @wanglun @lileding @Superjomn

self.assertTrue(
    out_purefp16_fp32.dtype == fluid.core.VarDesc.VarType.FP32
)
pass
Contributor:

What is the reason for commenting out this part, and what is the impact?

Contributor:

The assertion existed because the original expand_as operator did not support fp16; after we added fp16 support, the assertion here could be removed.

Contributor:

If registering fp16 causes errors, the root cause should be found and fixed, rather than commenting out all the failing checks.

Contributor (author):

Reverted; registering fp16 is no longer needed.


luotao1 commented Apr 11, 2023

@Timber-Ye @BrianQian1999 After discussion by the Hackathon organizing committee:

  1. Please remove the fp16 registration code; once it is removed and the PR passes review, it can be merged.
  2. If you are interested, you can continue working on the fp16 registration issue; solving it will earn an extra small gift.

cc @JamesLim-sy @zhangbopd

@Timber-Ye (author):

> @Timber-Ye @BrianQian1999 After discussion by the Hackathon organizing committee:
> 1. Please remove the fp16 registration code; once it is removed and the PR passes review, it can be merged.
> 2. If you are interested, you can continue working on the fp16 registration issue; solving it will earn an extra small gift.

Done.

#include "paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"

#define MAX_RANK_SUPPORTED 6
Contributor:

Try to avoid using macro substitution here.

Contributor (author):

Done.

PADDLE_ENFORCE_LE(
    out_rank,
    MAX_RANK_SUPPORTED,
    errors::InvalidArgument("The rank of the input 'Out@GRAD' for "
Contributor:

This reads a bit awkwardly.

Contributor (author):

Which part is awkward? This is borrowed directly from expand_grad_kernel_impl.h L94-L101.

@zhangbopd left a comment:

LGTM

@JamesLim-sy JamesLim-sy merged commit 3c44e94 into PaddlePaddle:develop Apr 17, 2023
jjyaoao pushed a commit to jjyaoao/Paddle that referenced this pull request Apr 19, 2023
…addle#52700)

* Implement optimized kernel for OP-expand_as.

* Support fp16.

Co-authored-by: Timber-Ye <ye_hanqiao@163.com>
Co-authored-by: BrianQian1999 <brianqianhitsz@gmail.com>

* remove fp16 support

* remove MAX_RANK_SUPPORTED

---------

Co-authored-by: BrianQian1999 <brianqianhitsz@gmail.com>
@luotao1

luotao1 commented Apr 20, 2023

hi, @Timber-Ye

  • Thank you very much for your contribution to the PaddlePaddle framework. We run a PFCC organization that continuously contributes to the framework through regular technical-knowledge sharing and developer-led tasks; see the profile at https://github.com/luotao1 for details.
  • If you are interested in PFCC, please send an email to ext_paddle_oss@baidu.com and we will invite you to join.
