
Optimize layer norm forward when cols is 1024. #39167

Merged (11 commits) · Jan 26, 2022

Conversation

@limin2021 (Contributor) commented on Jan 24, 2022

PR types

Performance optimization

PR changes

OPs

Describe

Optimize the performance of the layer_norm forward kernel used by the layer_norm op and the fused_dropout_residual_layer_norm op.

Performance results (kernel times collected with nsys):
(1) layer_norm op:

Time (ns):

| batch_size*seq_len | 28672 | 14336 | 7168 | 3584 | 1792 | 896 | 448 | 224 | 112 | 56 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| apex_fast_layer_norm | 74399.931 | 41035.3163 | 20555.8185 | 11111.16 | 8395.97375 | 6855.50875 | 6275.20375 | 5652.597 | 5469.5585 | 5408.00875 |
| paddle | 367119.365 | 187225.566 | 101543.132 | 54855.2878 | 29603.3463 | 18901.4535 | 13006.04 | 9529.146 | 8568.81325 | 8007.138 |
| paddle_opt | 74155.0323 | 40487.696 | 21343.357 | 11008.9173 | 8566.515 | 6866.36975 | 6272.02075 | 5826.07175 | 5712.10125 | 5598.60925 |

Speedup:

| batch_size*seq_len | 28672 | 14336 | 7168 | 3584 | 1792 | 896 | 448 | 224 | 112 | 56 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| apex/paddle | 0.20265869 | 0.21917582 | 0.20243436 | 0.20255404 | 0.2836157 | 0.36269744 | 0.48248381 | 0.5931903 | 0.63830992 | 0.67539847 |
| apex/paddle_opt | 1.00330252 | 1.0135256 | 0.96310147 | 1.00928727 | 0.98009211 | 0.99841823 | 1.00050749 | 0.97022441 | 0.95753879 | 0.96595574 |

Conclusion: after the optimization the kernel is 2-5x faster than before, and is essentially on par with the competing implementation. In the few cases where it is slightly slower than the competitor, the reason is that when computing scale * (x - mean) / var + bias, the competitor does the whole computation in fp16, whereas Paddle converts the data to fp32 and computes in fp32, which carries some extra cost compared with the all-fp16 path.
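For reference, the fp32 path described above looks roughly like the sketch below. This is illustrative only, not the actual kernel code merged in this PR; the helper name and the use of an inverse standard deviation rather than a raw variance are assumptions.

```cpp
#include <cuda_fp16.h>

// Illustrative sketch of the fp32 compute path described above: the fp16 input
// is converted to fp32, scale * (x - mean) * inv_var + bias is computed in
// fp32, and the result is converted back to fp16. The extra conversions are
// the overhead mentioned above relative to an all-fp16 implementation.
__device__ __forceinline__ __half NormalizeOneElement(__half x, float mean,
                                                       float inv_var,
                                                       float scale,
                                                       float bias) {
  float x_f32 = __half2float(x);                          // fp16 -> fp32
  float y_f32 = scale * (x_f32 - mean) * inv_var + bias;  // compute in fp32
  return __float2half(y_f32);                             // fp32 -> fp16
}
```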

(2) fused_dropout_residual_layer_norm op:

Time (ns):

| batch_size*seq_len | 28672 | 14336 | 7168 | 3584 | 1792 | 896 | 448 | 224 | 112 | 56 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| nv-mlperf-1.1 | 193148.0253 | 101705.2163 | 46974.2088 | 31105.4253 | 17486.40775 | 14277.5693 | 12606.103 | 12650.391 | 10847.596 | 10536.81 |
| paddle | 180305.741 | 93799.7805 | 51218.6403 | 31803.7753 | 21651.55675 | 16031.1558 | 13858.423 | 12667.184 | 12474.513 | 12245.71 |
| paddle-opt | 154923.7505 | 80434.30725 | 41667.2393 | 25594.6238 | 17437.93475 | 12359.7128 | 10222.886 | 9237.2238 | 9086.0273 | 8898.729 |

Speedup:

| batch_size*seq_len | 28672 | 14336 | 7168 | 3584 | 1792 | 896 | 448 | 224 | 112 | 56 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| nv/paddle | 1.071225044 | 1.084279896 | 0.91713112 | 0.97804192 | 0.807628197 | 0.89061384 | 0.9096348 | 0.9986743 | 0.8695807 | 0.860449 |
| nv/paddle-opt | 1.246729598 | 1.264450702 | 1.12736552 | 1.2153109 | 1.002779744 | 1.15516999 | 1.2331257 | 1.3695014 | 1.1938767 | 1.184081 |

Conclusion: after the optimization, the fused_dropout_residual_layer_norm module is roughly 10%-20% faster than the competing implementation.

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@@ -19,6 +19,9 @@ limitations under the License. */
namespace paddle {
namespace operators {

#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
Collaborator:
I prefer using a function instead of a macro.

Contributor Author:

Done, by using the std::ceil function.
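For illustration, replacing the DIVUP macro with a function could look like the sketch below; the names are placeholders, not necessarily what was merged. The first version follows the std::ceil route mentioned in the reply above, the second is a pure-integer equivalent.

```cpp
#include <cmath>

// std::ceil-based ceiling division, as described in the reply above.
static inline int CeilDiv(int x, int y) {
  return static_cast<int>(std::ceil(static_cast<float>(x) / y));
}

// Pure-integer alternative with the same result for positive operands,
// avoiding the float round trip.
constexpr int CeilDivInt(int x, int y) { return (x + y - 1) / y; }
```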

@@ -19,6 +19,9 @@ limitations under the License. */
namespace paddle {
namespace operators {

#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
#define COLS_ 1024
Collaborator:
How about giving this a more meaningful name? COLS_ is too terse to convey its exact meaning and can easily conflict with other macros.

Contributor Author:
Done.
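As an illustration, a more descriptive replacement for the COLS_ macro might look like the line below; the actual name chosen in the PR may differ.

```cpp
// Column count that this specialized forward kernel handles (was COLS_).
constexpr int kLayerNormFwdFastKernelCols = 1024;
```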

int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize,
int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
void *__restrict__ y_, void *__restrict__ residual_out_,
Collaborator:
The void * is too hard to read. Try to just write T *, U * or anything else.

Contributor Author:
Done.

void *__restrict__ y_, void *__restrict__ mean_out_,
void *__restrict__ var_out_, const void *__restrict__ x_,
const void *__restrict__ gamma_, const void *__restrict__ beta_,
const float epsilon, int rows, int cols) {
Collaborator:
Avoid using void *. Same as above.

Contributor Author:
Done.
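Putting the two excerpts above together, a typed version of the signature requested in these reviews might look like the following sketch. The template parameters are taken from the excerpts; T (input/output element type) and U (mean/variance type) are illustrative placeholders, and the exact merged signature may differ.

```cpp
template <typename T, typename U, int VecSize, int THREADS_PER_ROW,
          int THREADS_PER_CTA, int ELTS_PER_ROW,
          int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize,
          int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
    T *__restrict__ y_, T *__restrict__ residual_out_,
    U *__restrict__ mean_out_, U *__restrict__ var_out_,
    const T *__restrict__ x_, const T *__restrict__ gamma_,
    const T *__restrict__ beta_, const float epsilon, int rows, int cols);
```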

@sneaxiy (Collaborator) left a comment:

LGTM.
