Optimize the perf of top_k when k is too large #40941

Merged 6 commits on Mar 30, 2022

Contributor

@ZzSean ZzSean commented Mar 25, 2022

PR types

Performance optimization

PR changes

OPs

Describe

Optimize the perf of top_k when k is too large

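  • Development environment:
  1. Device: V100-16G
  2. Environment: CUDA 10.1, cuDNN 7
  • Optimization method:
  1. A radix-sort-style approach is used: bits are examined from the most significant to the least significant, two bits at a time. Each slice is handled by one block, and the per-digit counts are gathered with warp-level intrinsics and shared memory;
  2. The core idea is to find the k-th largest (or smallest) value and then select every element greater (or smaller) than it;
  3. If that value is not unique, the tied elements are selected in ascending index order;
  4. Once the top k values have been selected, cub is used to fully sort those k values.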
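| config | PyTorch (ms) | Paddle before (ms) | vs. PyTorch | Paddle after (ms) | vs. PyTorch | speedup |
|---|---|---|---|---|---|---|
| shape[136480], k=5000, fp32 | 1.12439 | 20.76580 | slower (17.47x) | 0.86577 | faster (19.70%) | 23.98x |
| shape[104903], k=5000, fp32 | 0.91584 | 17.57050 | slower (18.19x) | 0.79751 | on par (2.29%) | 22.03x |
| shape[133725], k=5000, fp32 | 1.10155 | 20.12270 | slower (17.27x) | 0.86220 | faster (14.36%) | 23.34x |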
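
The optimization method described above can be illustrated with a small single-threaded C++ sketch. This is not the PR's CUDA kernel: the function names (`ToOrderedBits`, `FindKthLargestKey`, `TopKLargest`) are hypothetical, the per-digit counting that the PR does with warp intrinsics and shared memory is a plain loop here, `std::stable_sort` stands in for the final cub sort, and NaN handling is omitted. It only shows the selection logic: narrow down the k-th largest key two bits at a time, take everything strictly greater, then fill the remaining slots with ties in ascending index order.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

constexpr int RADIX_BITS = 2;                    // two bits per pass
constexpr int RADIX_SIZE = 1 << RADIX_BITS;      // 4 buckets per pass
constexpr uint32_t RADIX_MASK = RADIX_SIZE - 1;

// Map a float to an unsigned key whose integer order matches the float order.
// Hypothetical helper for this sketch; NaNs are not handled.
inline uint32_t ToOrderedBits(float v) {
  uint32_t x;
  std::memcpy(&x, &v, sizeof(x));
  return (x & 0x80000000u) ? ~x : (x | 0x80000000u);
}

// Returns the key of the k-th largest element (requires 1 <= k <= data.size()).
uint32_t FindKthLargestKey(const std::vector<float>& data, int k) {
  uint32_t desired = 0;       // bits of the k-th key that are already fixed
  uint32_t desired_mask = 0;  // which bit positions are fixed
  int k_left = k;
  for (int digit_pos = static_cast<int>(sizeof(float) * 8) - RADIX_BITS;
       digit_pos >= 0; digit_pos -= RADIX_BITS) {
    // Count, per digit value, the elements that still match the fixed prefix.
    int counts[RADIX_SIZE] = {0};
    for (float v : data) {
      const uint32_t key = ToOrderedBits(v);
      if ((key & desired_mask) == desired) {
        ++counts[(key >> digit_pos) & RADIX_MASK];
      }
    }
    // Walk buckets from the largest digit down until the k-th element is found.
    for (int d = RADIX_SIZE - 1; d >= 0; --d) {
      if (counts[d] >= k_left) {
        desired |= static_cast<uint32_t>(d) << digit_pos;
        desired_mask |= RADIX_MASK << digit_pos;
        break;
      }
      k_left -= counts[d];
    }
  }
  return desired;
}

// Returns the k largest (value, index) pairs, sorted by value descending.
std::vector<std::pair<float, int>> TopKLargest(const std::vector<float>& data,
                                               int k) {
  const uint32_t kth = FindKthLargestKey(data, k);
  const int n = static_cast<int>(data.size());
  std::vector<std::pair<float, int>> out;
  out.reserve(k);
  // Pass 1: everything strictly greater than the k-th value.
  for (int i = 0; i < n; ++i) {
    if (ToOrderedBits(data[i]) > kth) out.emplace_back(data[i], i);
  }
  // Pass 2: ties with the k-th value, in ascending index order, until full.
  for (int i = 0; i < n && out.size() < static_cast<std::size_t>(k); ++i) {
    if (ToOrderedBits(data[i]) == kth) out.emplace_back(data[i], i);
  }
  // Final full sort of just the k selected values.
  std::stable_sort(out.begin(), out.end(),
                   [](const std::pair<float, int>& a,
                      const std::pair<float, int>& b) {
                     return a.first > b.first;
                   });
  return out;
}
```

For example, calling `TopKLargest({1, 5, 3, 5, 2, 5, -4}, 2)` selects the two 5s at indices 1 and 3, because the k-th value is tied and the lower indices win. The benefit for large k is that the selection cost depends on the number of bit groups rather than on k; only the final sort touches k elements.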

@paddle-bot-old

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

Copy(dev_ctx, sorted_output, out->place(), false, out);
return;
} else {
LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use "
Contributor

This uses LOG(INFO), but the message itself talks about errors. Wouldn't LOG(ERROR) be more appropriate? Alternatively, you could use Paddle's error-reporting functions: phi::errors::XXX

Contributor Author

@ZzSean ZzSean Mar 28, 2022

Reaching this branch does not mean execution has to stop; the computation can still be completed by one of the later branches, so reporting it as an error is not really appropriate.

Contributor

Use VLOG.

Contributor Author

Done, thanks.

int k_left = k;

#pragma unroll
for (int digit_pos = sizeof(T) * 8 - RADIX_BITS; digit_pos >= 0;
Contributor

Why `* 8` here?

Contributor Author

1 byte = 8 bits
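
To make the `sizeof(T) * 8` exchange concrete, here is a small sketch that enumerates the bit offsets such a loop would visit. `sizeof(T)` is a size in bytes, so multiplying by 8 gives the bit width of `T`. The helper name `DigitPositions` is hypothetical, and the `digit_pos -= RadixBits` step is an assumption, since the loop header in the quoted diff is cut off; the two-bit group size comes from the PR description.

```cpp
#include <vector>

// Lists the bit offsets visited when scanning T from its most significant
// digit group down to its least significant one, RadixBits bits at a time.
// Hypothetical helper; the decrement step is assumed, not taken from the diff.
template <typename T, int RadixBits = 2>
std::vector<int> DigitPositions() {
  std::vector<int> positions;
  // sizeof(T) is in bytes; * 8 converts it to bits (1 byte = 8 bits).
  for (int digit_pos = static_cast<int>(sizeof(T) * 8) - RadixBits;
       digit_pos >= 0; digit_pos -= RadixBits) {
    positions.push_back(digit_pos);
  }
  return positions;
}
```

With two-bit groups, `DigitPositions<float>()` yields 30, 28, ..., 0, which is 16 passes over a 32-bit key, and `DigitPositions<double>()` yields 32 passes. In standard C++ the portable spelling of the literal 8 is `CHAR_BIT` from `<climits>`.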

limin2021 previously approved these changes Mar 29, 2022
Contributor

@Xreki Xreki left a comment

LGTM

@ZzSean ZzSean merged commit 45078d9 into PaddlePaddle:develop Mar 30, 2022
4 participants