Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzSean committed Apr 26, 2022
1 parent 9130113 commit e4c4dd0
Showing 1 changed file with 20 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
phi::Load<T, VecSize>(bias_ptr + col * VecSize, &bias[it]);
col += THREADS_PER_ROW;
}
} else {
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
bias[it][jt] = static_cast<T>(0.0f);
}
}
}

Vec_scale gamma[LDGS];
Expand Down Expand Up @@ -276,15 +268,28 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(

// 4 * 8
U xf[LDGS * VecSize];
if (bias_ptr != nullptr) {
#pragma unroll
for (int it = 0; it < LDGS; it++) {
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
// dropout(x) + residual
x[it][jt] = (x[it][jt] + bias[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
xf[it * VecSize + jt] = U(x[it][jt]);
for (int jt = 0; jt < VecSize; jt++) {
// dropout(x) + residual
x[it][jt] = (x[it][jt] + bias[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
xf[it * VecSize + jt] = U(x[it][jt]);
}
}
} else {
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
// dropout(x) + residual
x[it][jt] = x[it][jt] * static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
xf[it * VecSize + jt] = U(x[it][jt]);
}
}
}

Expand Down

1 comment on commit e4c4dd0

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.