fix inf in fused_attention (#41933)
wangxicoding committed Apr 20, 2022
1 parent 4ef0a0b commit e76e1e2
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions paddle/fluid/operators/fused/fmha_ref.h
@@ -16,6 +16,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/transpose_op.cu.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/funcs/functors.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace paddle {
@@ -117,6 +119,18 @@ class FMHARef {
       v_ptr = k_ptr + k_size;
     }
 
+    {
+      // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for
+      // float16 calculation, INF may appear in QK^T if we do not scale before.
+      float alpha = 1.0 / sqrt(head_dim_);
+      auto q_tensor = transpose_2_out_tensor->Slice(0, 1);
+      auto functor = phi::funcs::ScaleFunctor<T>(alpha);
+      std::vector<const framework::Tensor*> ins = {&q_tensor};
+      std::vector<framework::Tensor*> outs = {&q_tensor};
+      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx_, ins,
+                                                                &outs, functor);
+    }
+
     // q*k^t, batched_gemm
     CBLAS_TRANSPOSE transA = CblasNoTrans;
     CBLAS_TRANSPOSE transB = CblasTrans;
@@ -125,7 +139,7 @@ class FMHARef {
     int gemm_m = seq_len_;
     int gemm_n = out_seq_len;
     int gemm_k = head_dim_;
-    T alpha = static_cast<T>(1.0 / sqrt(head_dim_));
+    T alpha = static_cast<T>(1.0);
     T beta = static_cast<T>(0.0);
     int64_t stride_a = gemm_m * gemm_k;
     int64_t stride_b = gemm_k * gemm_n;
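Editor's note: the two hunks above pre-scale Q by 1/sqrt(Dh) and then set the forward GEMM alpha to 1.0. The standalone C++ sketch below is not part of the commit; it emulates a single Q*K^T dot product in float and compares it against the float16 maximum (65504) to illustrate the overflow the NOTE describes. The head_dim of 64 and the activation magnitude of 40 are assumed values chosen only for illustration.

```cpp
// Hypothetical sketch (not Paddle code): shows why scaling Q before the
// QK^T product keeps the result representable in float16, while leaving the
// 1/sqrt(Dh) factor for after the product is too late.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int head_dim = 64;                // Dh (assumed)
  const float kHalfMax = 65504.0f;        // largest finite float16 value
  std::vector<float> q(head_dim, 40.0f);  // assumed activation magnitude
  std::vector<float> k(head_dim, 40.0f);

  const float alpha = 1.0f / std::sqrt(static_cast<float>(head_dim));
  float unscaled = 0.0f, prescaled = 0.0f;
  for (int i = 0; i < head_dim; ++i) {
    unscaled += q[i] * k[i];             // 64 * 1600 = 102400 > 65504
    prescaled += (alpha * q[i]) * k[i];  // 102400 / 8  =  12800, fits in fp16
  }
  std::printf("unscaled  = %.1f -> %s\n", unscaled,
              unscaled > kHalfMax ? "would be INF in float16" : "ok");
  std::printf("prescaled = %.1f -> %s\n", prescaled,
              prescaled > kHalfMax ? "would be INF in float16" : "ok");
  return 0;
}
```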
@@ -300,7 +314,9 @@ class FMHARef {
     }
 
     T* qk_out_grad_data = qk_out_grad_tensor->data<T>();
-    alpha = static_cast<T>(1.0 / sqrt(head_dim_));
+    // NOTE(wangxi): Since we scale Q with 1/sqrt(Dh) in the forward pass, we
+    // set alpha = 1.0 here in the backward pass.
+    alpha = static_cast<T>(1.0);
     // recall batchedgemm(nt) fw: q_ptr * (k_ptr)^t = qk_out
     //                        bw: dy (seq_len * head_dim) = (dout)^t * x
     transA = CblasTrans;
@@ -314,6 +330,7 @@ class FMHARef {
         qk_out_grad_data, q_ptr, beta, k_grad_ptr, gemm_batch_size,
         stride_a, stride_b);
     // dx (seq_len * head_dim) = dout * y
+    alpha = static_cast<T>(1.0 / sqrt(head_dim_));
     transA = CblasNoTrans;
     transB = CblasNoTrans;
     gemm_m = seq_len_;
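Editor's note: a scalar sanity check of the backward-pass alpha choices, not Paddle code. Because q_ptr already holds Q scaled by 1/sqrt(Dh) from the forward pass, the dK GEMM can use alpha = 1.0, while the dQ GEMM (gradient with respect to the unscaled Q) still needs alpha = 1/sqrt(Dh). The scalar values below are arbitrary stand-ins.

```cpp
// Hypothetical check (not from the commit): with qk = (alpha*q) * k,
//   dL/dk = alpha * q * d_qk  -> reproduced from the already-scaled q with a
//                                GEMM alpha of 1.0 (the dK hunk above), and
//   dL/dq = alpha * k * d_qk  -> still needs alpha = 1/sqrt(Dh) (the dQ hunk).
#include <cassert>
#include <cmath>

int main() {
  const double head_dim = 64.0;  // assumed Dh
  const double alpha = 1.0 / std::sqrt(head_dim);
  const double q = 3.0, k = 5.0, d_qk = 2.0;  // arbitrary stand-in values

  // Forward: pre-scaling q gives the same qk as scaling inside the GEMM.
  const double q_scaled = alpha * q;
  assert(std::abs(alpha * q * k - 1.0 * q_scaled * k) < 1e-12);

  // Backward for K: the stored, already-scaled q with alpha = 1.0 matches the
  // old result alpha * q * d_qk.
  const double dk_old = alpha * q * d_qk;
  const double dk_new = 1.0 * q_scaled * d_qk;
  assert(std::abs(dk_old - dk_new) < 1e-12);

  // Backward for Q: the gradient w.r.t. the *unscaled* q keeps the factor.
  const double dq = alpha * k * d_qk;
  (void)dq;
  return 0;
}
```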

1 comment on commit e76e1e2

@paddle-bot-old

Congratulations! Your pull request passed all required CI. You can ask the reviewer(s) to approve and merge. 🎉
