Skip to content

Commit

Permalink
[cherry-pick2.3]fix compile bug of windows cuda11.5
Browse files Browse the repository at this point in the history
  • Loading branch information
zhwesky2010 committed Apr 6, 2022
1 parent 5b85f3d commit 825b6e1
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions paddle/phi/kernels/funcs/activation_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1878,12 +1878,17 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> {

template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// exp(x) = expf(x)
__device__ __forceinline__ T operator()(const T x) const {
return static_cast<T>(expf(static_cast<float>(x)));
}
};

template <>
struct CudaExpFunctor<double> : public BaseActivationFunctor<double> {
// exp(x) = exp(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(exp(x));
__device__ __forceinline__ double operator()(const double x) const {
return exp(x);
}
};

Expand Down

0 comments on commit 825b6e1

Please sign in to comment.