add leaky_relu forward and backward in activation_op.cu #31841

Merged
168 changes: 145 additions & 23 deletions paddle/fluid/operators/activation_op.cu
@@ -42,18 +42,22 @@ template <typename T>
class BaseGPUFunctor {
public:
using ELEMENT_TYPE = T;

using AttrPair = std::vector<std::pair<const char*, float*>>;

AttrPair GetAttrs() { return AttrPair(); }
};

/* ========================================================================== */
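
The GetAttrs() hook above is how attribute-carrying activations plug into the
shared GPU kernels: a functor that needs an attribute (e.g. leaky relu's
alpha) returns {name, &field} pairs, and the host-side kernel copies the op
attribute into each field before the functor is handed to the launch. A
minimal sketch of that binding, matching the loop the kernels below use:

  Functor functor;
  for (auto& attr : functor.GetAttrs()) {
    *attr.second = context.Attr<float>(attr.first);  // e.g. fills alpha_
  }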

/* =========================== relu forward ============================ */
template <typename T>
-class ReluGPUFuctor : public BaseGPUFunctor<T> {
+class ReluGPUFunctor : public BaseGPUFunctor<T> {
private:
T zero_;

public:
-  ReluGPUFuctor() { zero_ = static_cast<T>(0.0f); }
+  ReluGPUFunctor() { zero_ = static_cast<T>(0.0f); }

// for relu forward when T is double
__device__ __forceinline__ typename CudaVecType<T>::type Compute(
@@ -67,7 +71,7 @@ class ReluGPUFuctor : public BaseGPUFunctor<T> {

template <>
__device__ __forceinline__ CudaVecType<double>::type
-ReluGPUFuctor<double>::Compute(const CudaVecType<double>::type* x) {
+ReluGPUFunctor<double>::Compute(const CudaVecType<double>::type* x) {
// relu forward : out = max(x, 0)
#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350 || CUDA_VERSION >= 300
return __ldg(x) > zero_ ? __ldg(x) : zero_;
@@ -78,15 +82,15 @@ ReluGPUFuctor<double>::Compute(const CudaVecType<double>::type* x) {

template <>
__device__ __forceinline__ CudaVecType<float>::type
-ReluGPUFuctor<float>::Compute(const CudaVecType<float>::type* xx) {
+ReluGPUFunctor<float>::Compute(const CudaVecType<float>::type* xx) {
// relu forward : out = max(xx, 0)
return make_float4((xx->x > zero_) * (xx->x), (xx->y > zero_) * (xx->y),
(xx->z > zero_) * (xx->z), (xx->w > zero_) * (xx->w));
}
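
Note the branchless form in the float4 path above: the comparison
(xx->x > zero_) promotes to 0 or 1, so multiplying by it yields max(x, 0) for
each lane without a divergent branch.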

template <>
__device__ __forceinline__ CudaVecType<float16>::type
-ReluGPUFuctor<float16>::Compute(const CudaVecType<float16>::type* in) {
+ReluGPUFunctor<float16>::Compute(const CudaVecType<float16>::type* in) {
// relu forward : out = max(in, 0)
#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350 || CUDA_VERSION >= 300
const half2 kzero = __float2half2_rn(0.0f);
@@ -162,6 +166,113 @@ ReluGradGPUFunctor<float16>::Compute(const CudaVecType<float16>::type* out,
#endif
}

/* ========================================================================== */
/* ======================== leaky relu forward ======================== */
template <typename T>
class LeakyReluGPUFunctor : public BaseGPUFunctor<T> {
private:
T zero_;
float alpha_;

public:
LeakyReluGPUFunctor() { zero_ = static_cast<T>(0.0f); }

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha_}};
}
// leakyrelu forward : out = x > 0 ? x : x * alpha
__device__ __forceinline__ typename CudaVecType<T>::type Compute(
const typename CudaVecType<T>::type* x) {
#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
return __ldg(x) > zero_ ? __ldg(x) : static_cast<T>(alpha_) * __ldg(x);
#else
return (*x) > zero_ ? (*x) : static_cast<T>(alpha_) * (*x);
#endif
}

__device__ __forceinline__ T ComputeRemainder(const T x) {
return x > zero_ ? x : static_cast<T>(alpha_) * x;
}
};

template <>
__device__ __forceinline__ CudaVecType<float>::type
LeakyReluGPUFunctor<float>::Compute(const CudaVecType<float>::type* xx) {
// leakyrelu forward : out = x > 0 ? x : x * alpha
return make_float4((xx->x > zero_) ? (xx->x) : (xx->x) * alpha_,
(xx->y > zero_) ? (xx->y) : (xx->y) * alpha_,
(xx->z > zero_) ? (xx->z) : (xx->z) * alpha_,
(xx->w > zero_) ? (xx->w) : (xx->w) * alpha_);
}

template <>
__device__ __forceinline__ CudaVecType<float16>::type
LeakyReluGPUFunctor<float16>::Compute(const CudaVecType<float16>::type* in) {
// leakyrelu forward : out = x > 0 ? x : x * alpha
const float2 xx = __half22float2(*in);
return __floats2half2_rn((xx.x > 0.0f) ? xx.x : xx.x * alpha_,
(xx.y > 0.0f) ? xx.y : xx.y * alpha_);
}
/* ========================================================================== */
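
For reference, all three forward paths above implement the same scalar rule,
out = x > 0 ? x : alpha * x, with float4 covering four lanes per call and
half2 two. A host-side check against the vectorized kernels, assuming
contiguous data (a sketch, not part of the diff):

  void LeakyReluRef(const float* x, float* out, int n, float alpha) {
    // scalar reference: out = x > 0 ? x : alpha * x
    for (int i = 0; i < n; ++i) {
      out[i] = x[i] > 0.0f ? x[i] : alpha * x[i];
    }
  }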

/* ======================== leaky relu backward ======================== */
template <typename T>
class LeakyReluGradGPUFunctor : public BaseGPUFunctor<T> {
private:
T zero_;
float alpha_;

public:
LeakyReluGradGPUFunctor() { zero_ = static_cast<T>(0.0f); }

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha_}};
}

// for leaky relu backward when T is double
__device__ __forceinline__ typename CudaVecType<T>::type Compute(
const typename CudaVecType<T>::type* out,
const typename CudaVecType<T>::type* dout) {
// leakyrelu backward : dx = out > 0 ? dout : alpha * dout
#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
return __ldg(out) > zero_ ? __ldg(dout)
                          : static_cast<T>(alpha_) * __ldg(dout);
#else
return (*out) > zero_ ? (*dout) : static_cast<T>(alpha_) * (*dout);
#endif
}

// when num % vecsize != 0 this func will be used
__device__ __forceinline__ T ComputeRemainder(const T out, const T dout) {
// leakyrelu backward : dx = out > 0 ? dout : alpha * dout
return out > zero_ ? dout : static_cast<T>(alpha_) * dout;
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <>
__device__ __forceinline__ CudaVecType<float>::type
LeakyReluGradGPUFunctor<float>::Compute(const CudaVecType<float>::type* out,
const CudaVecType<float>::type* dout) {
// leakyrelu backward : dx = out > 0 ? dout : alpha * dout
return make_float4((out->x > zero_) ? (dout->x) : alpha_ * (dout->x),
                   (out->y > zero_) ? (dout->y) : alpha_ * (dout->y),
                   (out->z > zero_) ? (dout->z) : alpha_ * (dout->z),
                   (out->w > zero_) ? (dout->w) : alpha_ * (dout->w));
}

template <>
__device__ __forceinline__ CudaVecType<float16>::type LeakyReluGradGPUFunctor<
float16>::Compute(const CudaVecType<float16>::type* out,
const CudaVecType<float16>::type* dout) {
const float2 xx = __half22float2(*out);
const float2 yy = __half22float2(*dout);
return __floats2half2_rn((xx.x > 0.0f) ? yy.x : alpha_ * yy.x,
                         (xx.y > 0.0f) ? yy.y : alpha_ * yy.y);
}

/* ========================================================================== */
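
The rule implemented by the backward paths above is
dx = dout * (x > 0 ? 1 : alpha). FwdDeps() returning kDepX indicates the
gradient is computed from the forward input, so the out argument is expected
to carry x rather than the forward output (for alpha > 0 the sign test gives
the same result either way). A scalar reference under that reading (a sketch,
not part of the diff):

  void LeakyReluGradRef(const float* x, const float* dout, float* dx, int n,
                        float alpha) {
    // dx = x > 0 ? dout : alpha * dout
    for (int i = 0; i < n; ++i) {
      dx[i] = x[i] > 0.0f ? dout[i] : alpha * dout[i];
    }
  }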

template <typename T, typename Functor>
@@ -231,6 +342,10 @@ class ActivationGPUKernel
block = 256;
#endif
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
constexpr int vecsize = CudaVecType<T>::vecsize;
int grid = max((num / vecsize + block - 1) / block, 1);
ActivationkernelVec<T, Functor><<<grid, block>>>(input_data, output_data,
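
ActivationkernelVec itself is outside this hunk; the usual shape of such a
kernel (a vectorized body built on Compute, plus a scalar tail built on
ComputeRemainder) would be roughly as follows. This is a sketch under that
assumption, with an illustrative name, not the code under review:

  template <typename T, typename Functor>
  __global__ void ActivationKernelVecSketch(const T* in, T* out, int num,
                                            Functor functor) {
    using VecType = typename CudaVecType<T>::type;
    constexpr int vecsize = CudaVecType<T>::vecsize;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = gridDim.x * blockDim.x;
    const VecType* vin = reinterpret_cast<const VecType*>(in);
    VecType* vout = reinterpret_cast<VecType*>(out);
    for (int i = idx; i < num / vecsize; i += stride) {
      vout[i] = functor.Compute(vin + i);  // full vectors
    }
    int tail = num - num % vecsize;
    for (int i = tail + idx; i < num; i += stride) {
      out[i] = functor.ComputeRemainder(in[i]);  // num % vecsize leftovers
    }
  }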
@@ -269,7 +384,12 @@ class ActivationGradGPUKernel
#ifdef __HIPCC__
block = 256;
#endif

Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
constexpr int vecsize = CudaVecType<T>::vecsize;
int grid = max((numel / vecsize + block - 1) / block, 1);
ActivationGradKernelVec<T, Functor><<<grid, block>>>(
@@ -298,12 +418,28 @@ namespace plat = paddle::platform;
ops::grad_functor<double>>, \
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);

FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL);

#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, op_name, functor, \
grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, ops::ActivationGPUKernel<paddle::platform::CUDADeviceContext, \
ops::functor<float>>, \
ops::ActivationGPUKernel<paddle::platform::CUDADeviceContext, \
ops::functor<double>>, \
ops::ActivationGPUKernel<plat::CUDADeviceContext, \
ops::functor<plat::float16>>); \
REGISTER_OP_CUDA_KERNEL( \
act_type##_grad, ops::ActivationGradGPUKernel<plat::CUDADeviceContext, \
ops::grad_functor<float>>, \
ops::ActivationGradGPUKernel<plat::CUDADeviceContext, \
ops::grad_functor<double>>, \
ops::ActivationGradGPUKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);

/* ======================== leaky relu register ============================ */
-REGISTER_ACTIVATION_CUDA_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
-                                LeakyReluGradFunctor);
+REGISTER_ACTIVATION_GPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluGPUFunctor,
+                               LeakyReluGradGPUFunctor);

REGISTER_OP_CUDA_KERNEL(
leaky_relu_grad_grad,
@@ -328,21 +464,7 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */

/* =========================== relu register ============================ */
-REGISTER_OP_CUDA_KERNEL(
-    relu, ops::ActivationGPUKernel<paddle::platform::CUDADeviceContext,
-                                   ops::ReluGPUFuctor<float>>,
-    ops::ActivationGPUKernel<paddle::platform::CUDADeviceContext,
-                             ops::ReluGPUFuctor<double>>,
-    ops::ActivationGPUKernel<plat::CUDADeviceContext,
-                             ops::ReluGPUFuctor<plat::float16>>);
-
-REGISTER_OP_CUDA_KERNEL(
-    relu_grad, ops::ActivationGradGPUKernel<paddle::platform::CUDADeviceContext,
-                                            ops::ReluGradGPUFunctor<float>>,
-    ops::ActivationGradGPUKernel<paddle::platform::CUDADeviceContext,
-                                 ops::ReluGradGPUFunctor<double>>,
-    ops::ActivationGradGPUKernel<plat::CUDADeviceContext,
-                                 ops::ReluGradGPUFunctor<plat::float16>>);
+REGISTER_ACTIVATION_GPU_KERNEL(relu, Relu, ReluGPUFunctor, ReluGradGPUFunctor);

REGISTER_OP_CUDA_KERNEL(
relu_grad_grad,