diff --git a/caffe2/operators/tanh_op.cc b/caffe2/operators/tanh_op.cc
new file mode 100644
index 00000000000..8612aa35f2d
--- /dev/null
+++ b/caffe2/operators/tanh_op.cc
@@ -0,0 +1,50 @@
+#include "caffe2/operators/tanh_op.h"
+
+namespace caffe2 {
+
+float tanh(float x) {
+  if (x >= 0) {
+    float enx = exp(-2.0*x);
+    return (1 - enx)/(1 + enx);
+  } else {
+    float epx = exp(2.0*x);
+    return (epx - 1)/(epx + 1);
+  }
+}
+
+template <>
+bool TanhOp<float, CPUContext>::RunOnDevice() {
+  auto& X = Input(0);
+  auto* Y = Output(0);
+  DCHECK_GT(X.size(), 0);
+  Y->ReshapeLike(X);
+  const float* Xdata = X.data();
+  float* Ydata = Y->mutable_data();
+  for (int i = 0; i < X.size(); ++i) {
+    Ydata[i] = tanh(Xdata[i]);
+  }
+  return true;
+}
+
+template <>
+bool TanhGradientOp<float, CPUContext>::RunOnDevice() {
+  auto& Y = Input(0);
+  auto& dY = Input(1);
+  auto* dX = Output(0);
+  DCHECK_GT(Y.size(), 0);
+  DCHECK_EQ(dY.size(), Y.size());
+  dX->ReshapeLike(Y);
+  const float* Ydata = Y.data();
+  const float* dYdata = dY.data();
+  float* dXdata = dX->mutable_data();
+  for (int i = 0; i < dX->size(); ++i) {
+    dXdata[i] = dYdata[i]*(1 - Ydata[i]*Ydata[i]);
+  }
+  return true;
+}
+
+namespace {
+REGISTER_CPU_OPERATOR(Tanh, TanhOp<float, CPUContext>)
+REGISTER_CPU_OPERATOR(TanhGradient, TanhGradientOp<float, CPUContext>)
+}  // namespace
+}  // namespace caffe2
diff --git a/caffe2/operators/tanh_op.cu b/caffe2/operators/tanh_op.cu
new file mode 100644
index 00000000000..e7b28afeb5b
--- /dev/null
+++ b/caffe2/operators/tanh_op.cu
@@ -0,0 +1,62 @@
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/tanh_op.h"
+
+namespace caffe2 {
+namespace {
+__device__ float tanh(float x) {
+  if (x >= 0) {
+    float enx = exp(-2.0*x);
+    return (1 - enx)/(1 + enx);
+  } else {
+    float epx = exp(2.0*x);
+    return (epx - 1)/(epx + 1);
+  }
+}
+
+template <typename dtype>
+__global__ void TanhKernel(const int N, const dtype* X, dtype* Y) {
+  CUDA_1D_KERNEL_LOOP(i, N) {
+    Y[i] = tanh(X[i]);
+  }
+}
+
+template <typename dtype>
+__global__ void TanhGradientKernel(const int N, const dtype* Y, const dtype* dY,
+                                   dtype* dX) {
+  CUDA_1D_KERNEL_LOOP(i, N) {
+    dX[i] = dY[i]*(1 - Y[i]*Y[i]);
+  }
+}
+}  // namespace
+
+template <>
+bool TanhOp<float, CUDAContext>::RunOnDevice() {
+  auto& X = Input(0);
+  auto* Y = Output(0);
+  DCHECK_GT(X.size(), 0);
+  Y->ReshapeLike(X);
+  TanhKernel<<<CAFFE_GET_BLOCKS(X.size()), CAFFE_CUDA_NUM_THREADS,
+               0, device_context_.cuda_stream()>>>(
+      X.size(), X.data(), Y->mutable_data());
+  return true;
+}
+
+template <>
+bool TanhGradientOp<float, CUDAContext>::RunOnDevice() {
+  auto& Y = Input(0);
+  auto& dY = Input(1);
+  auto* dX = Output(0);
+  DCHECK_GT(Y.size(), 0);
+  DCHECK_EQ(dY.size(), Y.size());
+  dX->ReshapeLike(Y);
+  TanhGradientKernel<<<CAFFE_GET_BLOCKS(Y.size()), CAFFE_CUDA_NUM_THREADS,
+                       0, device_context_.cuda_stream()>>>(
+      Y.size(), Y.data(), dY.data(), dX->mutable_data());
+  return true;
+}
+
+namespace {
+REGISTER_CUDA_OPERATOR(Tanh, TanhOp<float, CUDAContext>)
+REGISTER_CUDA_OPERATOR(TanhGradient, TanhGradientOp<float, CUDAContext>)
+}  // namespace
+}  // namespace caffe2
diff --git a/caffe2/operators/tanh_op.h b/caffe2/operators/tanh_op.h
new file mode 100644
index 00000000000..1893a4d0049
--- /dev/null
+++ b/caffe2/operators/tanh_op.h
@@ -0,0 +1,40 @@
+#ifndef CAFFE2_OPERATORS_TANH_OP_H_
+#define CAFFE2_OPERATORS_TANH_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/utils/math.h"
+#include "glog/logging.h"
+
+namespace caffe2 {
+
+template <typename dtype, class DeviceContext>
+class TanhOp final : public Operator<dtype, DeviceContext> {
+ public:
+  USE_SIMPLE_CTOR_DTOR(TanhOp);
+  USE_OPERATOR_BASE_FUNCTIONS;
+
+  bool RunOnDevice();
+
+ protected:
+  INPUT_OUTPUT_STATS(1, 1, 1, 1);
+  DISABLE_COPY_AND_ASSIGN(TanhOp);
+};
+
+template <typename dtype, class DeviceContext>
+class TanhGradientOp final : public Operator<dtype, DeviceContext> {
+ public:
+  USE_SIMPLE_CTOR_DTOR(TanhGradientOp);
+  USE_OPERATOR_BASE_FUNCTIONS;
+
+  bool RunOnDevice();
+
+ protected:
+  // Input: Y, dY; Output: dX
+  INPUT_OUTPUT_STATS(2, 2, 1, 1);
+  DISABLE_COPY_AND_ASSIGN(TanhGradientOp);
+};
+
+}  // namespace caffe2
+
+#endif  // CAFFE2_OPERATORS_TANH_OP_H_
diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h
index b5a03e232fd..ac7a06aa7d1 100644
--- a/caffe2/utils/math.h
+++ b/caffe2/utils/math.h
@@ -21,6 +21,8 @@ template <typename T, class DeviceContext>
 void Log(const int N, const T* x, T* y, DeviceContext* context);
 template <typename T, class DeviceContext>
 void Sqr(const int N, const T* x, T* y, DeviceContext* context);
+template <typename T, class DeviceContext>
+void Tanh(const int N, const T* x, T* y, DeviceContext* context);
 
 template <typename T, class DeviceContext>
 void Powx(const int N, const T* a, const T b, T* y, DeviceContext* context);
diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc
index 268fc31eea0..66cec376fd6 100644
--- a/caffe2/utils/math_cpu.cc
+++ b/caffe2/utils/math_cpu.cc
@@ -38,6 +38,8 @@ DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log, vsLn)
 DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log, vdLn)
 DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sqr, vsSqr)
 DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sqr, vdSqr)
+DELEGATE_SIMPLE_UNARY_FUNCTION(float, Tanh, vsTanh)
+DELEGATE_SIMPLE_UNARY_FUNCTION(double, Tanh, vdTanh)
 
 template <>
 void Powx<float, CPUContext>(
diff --git a/caffe2/utils/mkl_alternate.h b/caffe2/utils/mkl_alternate.h
index 340b0791b7a..13bff49eb05 100644
--- a/caffe2/utils/mkl_alternate.h
+++ b/caffe2/utils/mkl_alternate.h
@@ -33,10 +33,21 @@ extern "C" {
     v##name(n, a, y); \
   }
 
+inline float tanh(const float& x) {
+  if (x >= 0) {
+    float enx = exp(-2.0f*x);
+    return (1.0f - enx)/(1.0f + enx);
+  } else {
+    float epx = exp(2.0f*x);
+    return (epx - 1.0f)/(epx + 1.0f);
+  }
+}
+
 DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]);
 DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i]));
 DEFINE_VSL_UNARY_FUNC(Ln, y[i] = std::log(a[i]));
 DEFINE_VSL_UNARY_FUNC(Abs, y[i] = fabs(a[i]));
+DEFINE_VSL_UNARY_FUNC(Tanh, y[i] = tanh(a[i]));
 
 // A simple way to define the vsl unary functions with singular parameter b.
 // The operation should be in the form e.g. y[i] = pow(a[i], b)
diff --git a/pycaffe2/core_gradients.py b/pycaffe2/core_gradients.py
index dca083bea8f..dea9e73ff91 100644
--- a/pycaffe2/core_gradients.py
+++ b/pycaffe2/core_gradients.py
@@ -26,6 +26,12 @@ def AddSoftmaxGradient(op):
       [op.output[0], GetGradientName(op.output[0])],
       [GetGradientName(op.input[0])])
 
+@GradientRegistry.RegisterGradient("Tanh")
+def AddTanhGradient(op):
+  return CreateOperator('TanhGradient')(
+      [op.output[0], GetGradientName(op.output[0])],
+      [GetGradientName(op.input[0])])
+
 @GradientRegistry.RegisterGradient("Flatten")
 def AddFlattenGradient(op):
   return CreateOperator('ReshapeLike')(
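
Not part of the patch: the two-branch formulation used throughout the diff only ever evaluates exp() on a non-positive argument, so it cannot overflow for large |x|, and TanhGradientOp relies on the identity d tanh(x)/dx = 1 - tanh(x)^2, which is why it reads the forward output Y instead of the input X. The standalone C++ sketch below is purely illustrative; the helper name StableTanh and the sample values are invented here, not taken from the patch. It checks the branch-based tanh against std::tanh and compares 1 - y*y with a centered finite difference.

// Standalone consistency check for the tanh formulation used in this patch.
#include <cmath>
#include <cstdio>

// Same two-branch form as the patch: for x >= 0 use exp(-2x), which stays in
// (0, 1]; mirror the expression for x < 0 so the exponential never overflows.
static float StableTanh(float x) {
  if (x >= 0) {
    float enx = std::exp(-2.0f * x);
    return (1.0f - enx) / (1.0f + enx);
  } else {
    float epx = std::exp(2.0f * x);
    return (epx - 1.0f) / (epx + 1.0f);
  }
}

int main() {
  const float xs[] = {-50.0f, -3.0f, -0.5f, 0.0f, 0.5f, 3.0f, 50.0f};
  for (float x : xs) {
    float y = StableTanh(x);
    // Forward check against the standard library.
    float ref = std::tanh(x);
    // Gradient check: 1 - y*y (what TanhGradientOp multiplies dY by)
    // versus a centered finite difference of the forward function.
    float eps = 1e-3f;
    float fd = (StableTanh(x + eps) - StableTanh(x - eps)) / (2.0f * eps);
    std::printf("x=%8.2f  tanh=%+.6f (ref %+.6f)  1-y^2=%.6f (fd %.6f)\n",
                x, y, ref, 1.0f - y * y, fd);
  }
  return 0;
}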