Skip to content

Commit

Permalink
[kunlun] prevent overflow in collective softmax_with_ce (#52356)
Browse files Browse the repository at this point in the history
* [kunlun] prevent numerical overflow in collective softmax_with_ce

* add fix in another branch
  • Loading branch information
XiaociZhang authored Mar 31, 2023
1 parent 4c6ad5c commit fb276f2
Showing 1 changed file with 14 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
};
phi::XPUElementwise<T, XPUType>(
dev_ctx, logits_2d, logits_max, axis, &softmax_2d, f);
ret = xpu::clip<XPUType>(dev_ctx.x_context(),
reinterpret_cast<XPUType*>(softmax_2d.data<T>()),
reinterpret_cast<XPUType*>(softmax_2d.data<T>()),
N * D,
-64.,
0.);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "clip");
}

// step 3, obtain predict target
Expand Down Expand Up @@ -322,6 +329,13 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
};
phi::XPUElementwise<T, XPUType>(
dev_ctx, logits_2d, logits_max, axis, &softmax_2d, f);
ret = xpu::clip<XPUType>(dev_ctx.x_context(),
reinterpret_cast<XPUType*>(softmax_2d.data<T>()),
reinterpret_cast<XPUType*>(softmax_2d.data<T>()),
N * D,
-64.,
0.);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "clip");
}

// step 3, obtain predict target
Expand Down

0 comments on commit fb276f2

Please sign in to comment.