diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h index 7bf94e98d6261..aa613dd3f5ce0 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h @@ -221,14 +221,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( phi::Load(bias_ptr + col * VecSize, &bias[it]); col += THREADS_PER_ROW; } - } else { -#pragma unroll - for (int it = 0; it < LDGS; it++) { -#pragma unroll - for (int jt = 0; jt < VecSize; jt++) { - bias[it][jt] = static_cast(0.0f); - } - } } Vec_scale gamma[LDGS]; @@ -276,15 +268,28 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( // 4 * 8 U xf[LDGS * VecSize]; + if (bias_ptr != nullptr) { #pragma unroll - for (int it = 0; it < LDGS; it++) { + for (int it = 0; it < LDGS; it++) { #pragma unroll - for (int jt = 0; jt < VecSize; jt++) { - // dropout(x) + residual - x[it][jt] = (x[it][jt] + bias[it][jt]) * - static_cast(mask_vec[it][jt]) * factor + - residual[it][jt]; - xf[it * VecSize + jt] = U(x[it][jt]); + for (int jt = 0; jt < VecSize; jt++) { + // dropout(x) + residual + x[it][jt] = (x[it][jt] + bias[it][jt]) * + static_cast(mask_vec[it][jt]) * factor + + residual[it][jt]; + xf[it * VecSize + jt] = U(x[it][jt]); + } + } + } else { +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < VecSize; jt++) { + // dropout(x) + residual + x[it][jt] = x[it][jt] * static_cast(mask_vec[it][jt]) * factor + + residual[it][jt]; + xf[it * VecSize + jt] = U(x[it][jt]); + } } }