【Complex OP】No.30 complex stanh op (#57639)
yangguohao authored Sep 25, 2023
1 parent 50bb3e1 commit 395ffbf
Showing 6 changed files with 73 additions and 5 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -303,7 +303,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/activation_kernel.cc
@@ -197,7 +197,7 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
-PD_REGISTER_ACTIVATION_KERNEL(stanh, STanhKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, STanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
52 changes: 52 additions & 0 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -599,6 +599,32 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct STanhGradFunctor<ComplexType<T>>
    : public BaseActivationFunctor<ComplexType<T>> {
  float scale_a;
  float scale_b;
  typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    auto a = static_cast<ComplexType<T>>(scale_a);  // NOLINT
    auto b = static_cast<ComplexType<T>>(scale_b);
    auto temp = (a * x).tanh() * (a * x).tanh();
    dx.device(d) =
        dout *
        (a * b * (static_cast<ComplexType<T>>(1) - temp)).unaryExpr(Conj<T>());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct Tangent {
  HOSTDEVICE T operator()(const T& val) const { return tan(val); }

@@ -3578,6 +3604,32 @@ struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSTanhGradFunctor<ComplexType<T>>
    : public BaseActivationFunctor<ComplexType<T>> {
  ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);
  float scale_a;
  float scale_b;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // dx = dout * conj(a * b * (1 - tanh(a * x) * tanh(a * x)))
  __device__ __forceinline__ ComplexType<T> operator()(
      const ComplexType<T> arg_dout, const ComplexType<T> arg_x) const {
    ComplexType<T> dout = static_cast<ComplexType<T>>(arg_dout);
    ComplexType<T> x = static_cast<ComplexType<T>>(arg_x);
    ComplexType<T> a = static_cast<ComplexType<T>>(scale_a);
    ComplexType<T> b = static_cast<ComplexType<T>>(scale_b);
    ComplexType<T> temp = tanh(a * x);
    return static_cast<ComplexType<T>>(dout *
                                       conj(a * b * (one - temp * temp)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
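For reference, here is a sketch of the math the two complex specializations above implement, assuming Paddle's stanh definition out = scale_b * tanh(scale_a * x). The overline denotes complex conjugation, which is what the unaryExpr(Conj<T>()) call in the CPU functor and the conj(...) call in the CUDA functor apply:

\mathrm{out} = b\,\tanh(a x), \qquad a = \text{scale\_a},\; b = \text{scale\_b}

dx = \mathrm{dout} \cdot \overline{a\, b\, \left(1 - \tanh^{2}(a x)\right)}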
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -379,7 +379,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
ThresholdedReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_double_grad,
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/activation_kernel.cu
@@ -246,7 +246,7 @@ PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
-PD_REGISTER_ACTIVATION_KERNEL(stanh, StanhKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, StanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
18 changes: 17 additions & 1 deletion test/legacy_test/test_activation_op.py
@@ -3694,7 +3694,13 @@ def setUp(self):
        scale_b = self.get_scale_b()

        np.random.seed(1024)
-        x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        if self.dtype is np.complex64 or self.dtype is np.complex128:
+            x = (
+                np.random.uniform(0.1, 1, self.shape)
+                + 1j * np.random.uniform(0.1, 1, self.shape)
+            ).astype(self.dtype)
+        else:
+            x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
        # The same reason with TestAbs
        out = ref_stanh(x, scale_a, scale_b)

@@ -3724,6 +3730,16 @@ def init_shape(self):
        self.shape = []


class TestSTanhComplex64(TestSTanh):
    def init_dtype(self):
        self.dtype = np.complex64


class TestSTanhComplex128(TestSTanh):
    def init_dtype(self):
        self.dtype = np.complex128


class TestSTanhAPI(unittest.TestCase):
    # test paddle.nn.stanh
    def get_scale_a(self):
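The reference function ref_stanh used in setUp above is defined elsewhere in test_activation_op.py and is not part of this diff. As a minimal sketch, a NumPy reference consistent with the kernels registered in this commit could look like the following (the default scales 0.67 and 1.7159 and the shape are illustrative assumptions, not taken from the test file):

import numpy as np


def ref_stanh(x, scale_a=0.67, scale_b=1.7159):
    # stanh(x) = scale_b * tanh(scale_a * x); np.tanh also handles complex input
    return scale_b * np.tanh(scale_a * x)


# Build a complex input the same way TestSTanh.setUp does above.
np.random.seed(1024)
shape = [10, 12]
x = (
    np.random.uniform(0.1, 1, shape) + 1j * np.random.uniform(0.1, 1, shape)
).astype(np.complex64)
out = ref_stanh(x)

With the complex kernels registered on both CPU and GPU, paddle.stanh on a complex64 or complex128 tensor is expected to match such a reference; the new TestSTanhComplex64 and TestSTanhComplex128 cases exercise that path through the existing TestSTanh checks.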
