Optimize layer norm forward when cols is 1024. #39167
Conversation
… fused_dropout_op_support_ln_fp16
Thanks for your contribution!
@@ -19,6 +19,9 @@ limitations under the License. */
namespace paddle {
namespace operators {

#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
I prefer using a function instead of a macro.
Done, by using the std::ceil function.
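For reference, a minimal sketch (not the exact code in this PR) of two function-based alternatives to the DIVUP macro: a constexpr integer helper, and the std::ceil form mentioned in the reply above.

#include <cmath>

// Ceil-division as a constexpr function; unlike the macro, the arguments
// are evaluated exactly once and the call is type-checked.
constexpr int DivUp(int x, int y) { return (x + y - 1) / y; }

// The std::ceil variant mentioned above; note the cast to float and back.
inline int DivUpCeil(int x, int y) {
  return static_cast<int>(std::ceil(static_cast<float>(x) / y));
}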
@@ -19,6 +19,9 @@ limitations under the License. */
namespace paddle {
namespace operators {

#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
#define COLS_ 1024
How about giving this a more meaningful name? The name COLS_ is too terse to convey the exact meaning and can easily conflict with other macros.
Done.
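As an illustration of the rename (the actual name chosen in the merged code may differ), a descriptive constant is easier to read and less likely to collide than the bare COLS_ macro:

// Hypothetical replacement name; the fast path below is specialized for
// rows that are exactly 1024 elements wide.
static constexpr int kLayerNormFwdFastKernelCols = 1024;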
          int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize,
          int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
    void *__restrict__ y_, void *__restrict__ residual_out_,
The void * is too hard to read. Try to just write T *, U *, or anything else.
Done.
    void *__restrict__ y_, void *__restrict__ mean_out_,
    void *__restrict__ var_out_, const void *__restrict__ x_,
    const void *__restrict__ gamma_, const void *__restrict__ beta_,
    const float epsilon, int rows, int cols) {
Avoid using void *. Same as above.
Done.
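A minimal sketch of the typed-pointer signature the reviewer asks for (the template parameter names T, U, and ScaleT are assumptions, not necessarily the ones in the merged code): data pointers use T, the mean/variance outputs use U, and gamma/beta use a separate scale type, so readers no longer need to trace what each void * actually holds.

// Sketch only: a forward declaration showing the shape of the typed
// interface; the real kernel body does the vectorized load, reduction,
// normalization, and store for 1024-column rows.
template <typename T, typename U, typename ScaleT, int VecSize,
          int THREADS_PER_CTA>
__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
    T *__restrict__ y, U *__restrict__ mean_out, U *__restrict__ var_out,
    const T *__restrict__ x, const ScaleT *__restrict__ gamma,
    const ScaleT *__restrict__ beta, const float epsilon, int rows, int cols);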
LGTM.
PR types
Performance optimization
PR changes
OPs
Describe
Optimize the performance of the layer_norm forward kernel for the layer_norm op and the fused_dropout_residual_layer_norm op.
Performance results (kernel times collected with nsys):
(1) layer_norm op:
Conclusion: the optimized kernel is 2-5x faster than before and is roughly on par with the competing implementation. In a few cases it is slightly slower than the competitor because, when computing scale * (x - mean) / var + bias, the competitor does the whole computation in fp16, whereas Paddle converts to fp32 and computes in fp32, which costs some performance relative to the pure-fp16 approach.
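To make the fp32-accumulation point concrete, a hedged sketch (names are illustrative, not the PR's code) of how one output element could be produced: the fp16 input, gamma, and beta are converted to fp32, gamma * (x - mean) * inv_var + beta is evaluated in fp32, and the result is cast back to fp16. Doing the same arithmetic directly in __half, as the competitor does, avoids the conversions at some cost in precision.

#include <cuda_fp16.h>

// inv_var is assumed to be 1 / sqrt(var + epsilon), precomputed per row.
__device__ __forceinline__ __half ScaleShiftFp32(__half x, float mean,
                                                 float inv_var, __half gamma,
                                                 __half beta) {
  // Convert to fp32, apply gamma * (x - mean) * inv_var + beta, cast back.
  float out = __half2float(gamma) * (__half2float(x) - mean) * inv_var +
              __half2float(beta);
  return __float2half(out);
}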
(2) fused_dropout_residual_layer_norm op:
Conclusion: after the optimization, the fused_dropout_residual_layer_norm module is roughly 10%-20% faster than the competing implementation.