add tanh_op #14

Closed · wants to merge 2 commits
50 changes: 50 additions & 0 deletions caffe2/operators/tanh_op.cc
@@ -0,0 +1,50 @@
#include "caffe2/operators/tanh_op.h"

namespace caffe2 {

float tanh(float x) {
  if (x >= 0) {
    float enx = exp(-2.0*x);
    return (1 - enx)/(1 + enx);
  } else {
    float epx = exp(2.0*x);
    return (epx - 1)/(epx + 1);
  }
}

template <>
bool TanhOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);
  auto* Y = Output(0);
  DCHECK_GT(X.size(), 0);
  Y->ReshapeLike(X);
  const float* Xdata = X.data();
  float* Ydata = Y->mutable_data();
  for (int i = 0; i < X.size(); ++i) {
    Ydata[i] = tanh(Xdata[i]);
  }
  return true;
}

template <>
bool TanhGradientOp<float, CPUContext>::RunOnDevice() {
  auto& Y = Input(0);
  auto& dY = Input(1);
  auto* dX = Output(0);
  DCHECK_GT(Y.size(), 0);
  DCHECK_EQ(dY.size(), Y.size());
  dX->ReshapeLike(Y);
  const float* Ydata = Y.data();
  const float* dYdata = dY.data();
  float* dXdata = dX->mutable_data();
  for (int i = 0; i < dX->size(); ++i) {
    dXdata[i] = dYdata[i]*(1 - Ydata[i]*Ydata[i]);
  }
  return true;
}

namespace {
REGISTER_CPU_OPERATOR(Tanh, TanhOp<float, CPUContext>)
REGISTER_CPU_OPERATOR(TanhGradient, TanhGradientOp<float, CPUContext>)
} // namespace
} // namespace caffe2
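
A note on the sign branch in the tanh helper above: the two expressions are the same function, rewritten so that the exponential argument is never positive, which keeps exp() from overflowing for large |x|:

  tanh(x) = (e^{x} - e^{-x}) / (e^{x} + e^{-x})
          = (1 - e^{-2x}) / (1 + e^{-2x})   (used for x >= 0, where e^{-2x} <= 1)
          = (e^{2x} - 1) / (e^{2x} + 1)     (used for x < 0, where e^{2x} < 1)

Both branches stay within single-precision range and approach +/-1 as |x| grows, whereas the unbranched middle form would overflow e^{-2x} for x below about -44, and the symmetric form overflows expf once x exceeds roughly 88.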
62 changes: 62 additions & 0 deletions caffe2/operators/tanh_op.cu
@@ -0,0 +1,62 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/tanh_op.h"

namespace caffe2 {
namespace {
__device__ float tanh(float x) {
  if (x >= 0) {
    float enx = exp(-2.0*x);
    return (1 - enx)/(1 + enx);
  } else {
    float epx = exp(2.0*x);
    return (epx - 1)/(epx + 1);
  }
}

template <typename dtype>
__global__ void TanhKernel(const int N, const dtype* X, dtype* Y) {
  CUDA_1D_KERNEL_LOOP(i, N) {
    Y[i] = tanh(X[i]);
  }
}

template <typename dtype>
__global__ void TanhGradientKernel(const int N, const dtype* Y, const dtype* dY,
                                   dtype* dX) {
  CUDA_1D_KERNEL_LOOP(i, N) {
    dX[i] = dY[i]*(1 - Y[i]*Y[i]);
  }
}
} // namespace

template <>
bool TanhOp<float, CUDAContext>::RunOnDevice() {
  auto& X = Input(0);
  auto* Y = Output(0);
  DCHECK_GT(X.size(), 0);
  Y->ReshapeLike(X);
  TanhKernel<<<CAFFE_GET_BLOCKS(X.size()), CAFFE_CUDA_NUM_THREADS,
               0, device_context_.cuda_stream()>>>(
      X.size(), X.data(), Y->mutable_data());
  return true;
}

template <>
bool TanhGradientOp<float, CUDAContext>::RunOnDevice() {
  auto& Y = Input(0);
  auto& dY = Input(1);
  auto* dX = Output(0);
  DCHECK_GT(Y.size(), 0);
  DCHECK_EQ(dY.size(), Y.size());
  dX->ReshapeLike(Y);
  TanhGradientKernel<<<CAFFE_GET_BLOCKS(Y.size()), CAFFE_CUDA_NUM_THREADS,
                       0, device_context_.cuda_stream()>>>(
      Y.size(), Y.data(), dY.data(), dX->mutable_data());
  return true;
}

namespace {
REGISTER_CUDA_OPERATOR(Tanh, TanhOp<float, CUDAContext>)
REGISTER_CUDA_OPERATOR(TanhGradient, TanhGradientOp<float, CUDAContext>)
} // namespace
} // namespace caffe2
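
For readers who have not seen the launch helpers used above: CAFFE_CUDA_NUM_THREADS is the per-block thread count, CAFFE_GET_BLOCKS(N) is roughly ceil(N / CAFFE_CUDA_NUM_THREADS), and CUDA_1D_KERNEL_LOOP is the usual Caffe-style grid-stride loop, so each kernel invocation covers the whole flattened tensor for any size. A rough sketch of that macro (the real definition lives in the included caffe2 GPU headers and may differ in detail):

// Sketch only -- approximation of the Caffe-style grid-stride loop macro,
// not copied from the caffe2 headers.
#define CUDA_1D_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);    \
       i += blockDim.x * gridDim.x)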
40 changes: 40 additions & 0 deletions caffe2/operators/tanh_op.h
@@ -0,0 +1,40 @@
#ifndef CAFFE2_OPERATORS_TANH_OP_H_
#define CAFFE2_OPERATORS_TANH_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
#include "glog/logging.h"

namespace caffe2 {

template <typename dtype, class DeviceContext>
class TanhOp final : public Operator<dtype, DeviceContext> {
 public:
  USE_SIMPLE_CTOR_DTOR(TanhOp);
  USE_OPERATOR_BASE_FUNCTIONS;

  bool RunOnDevice();

 protected:
  INPUT_OUTPUT_STATS(1, 1, 1, 1);
  DISABLE_COPY_AND_ASSIGN(TanhOp);
};

template <typename dtype, class DeviceContext>
class TanhGradientOp final : public Operator<dtype, DeviceContext> {
 public:
  USE_SIMPLE_CTOR_DTOR(TanhGradientOp);
  USE_OPERATOR_BASE_FUNCTIONS;

  bool RunOnDevice();

 protected:
  // Input: X, dY; Output: dX
Review comment (Contributor): Input should be "Y, dY" instead of "X, dY"?

  INPUT_OUTPUT_STATS(2, 2, 1, 1);
  DISABLE_COPY_AND_ASSIGN(TanhGradientOp);
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_TANH_OP_H_
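
On the review question above about the gradient inputs: since d/dx tanh(x) = 1 - tanh(x)^2, the backward pass only needs the forward output Y = tanh(X) together with dY,

  dX[i] = dY[i] * (1 - Y[i] * Y[i]),

which is exactly what the CPU and CUDA TanhGradient implementations compute, so the operator's two inputs are (Y, dY) rather than (X, dY); only the comment string appears to be off.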
2 changes: 2 additions & 0 deletions caffe2/utils/math.h
@@ -21,6 +21,8 @@ template <typename T, class DeviceContext>
void Log(const int N, const T* x, T* y, DeviceContext* context);
template <typename T, class DeviceContext>
void Sqr(const int N, const T* x, T* y, DeviceContext* context);
template <typename T, class DeviceContext>
void Tanh(const int N, const T* x, T* y, DeviceContext* context);

template <typename T, class DeviceContext>
void Powx(const int N, const T* a, const T b, T* y, DeviceContext* context);
2 changes: 2 additions & 0 deletions caffe2/utils/math_cpu.cc
@@ -38,6 +38,8 @@ DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log, vsLn)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log, vdLn)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sqr, vsSqr)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sqr, vdSqr)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Tanh, vsTanh)
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Tanh, vdTanh)
Review comment (Contributor): (not sure) is MKL using vsTanh or vstanh? (i.e. capital or non-capital?)


template <>
void Powx<float, CPUContext>(
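
On the capitalization question above: as far as I can tell, Intel MKL's VML routines are spelled vsTanh / vdTanh (capital T), so the delegation added here should resolve as written, and the mkl_alternate.h fallback below uses the same names. For context only (this is an assumption about the macro, not part of the diff), DELEGATE_SIMPLE_UNARY_FUNCTION presumably expands the float case to something like:

// Hypothetical expansion of DELEGATE_SIMPLE_UNARY_FUNCTION(float, Tanh, vsTanh);
// the actual macro is defined earlier in math_cpu.cc and may differ in detail.
template <>
void Tanh<float, CPUContext>(const int N, const float* x, float* y,
                             CPUContext* context) {
  vsTanh(N, x, y);  // MKL VML call, or the mkl_alternate.h loop fallback
}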
11 changes: 11 additions & 0 deletions caffe2/utils/mkl_alternate.h
@@ -33,10 +33,21 @@ extern "C" {
    v##name<double>(n, a, y); \
  }

inline float tanh(const float& x) {
  if (x >= 0) {
    float enx = exp(-2.0f*x);
    return (1.0f - enx)/(1.0f + enx);
  } else {
    float epx = exp(2.0f*x);
    return (epx - 1.0f)/(epx + 1.0f);
  }
}

DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]);
DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i]));
DEFINE_VSL_UNARY_FUNC(Ln, y[i] = std::log(a[i]));
DEFINE_VSL_UNARY_FUNC(Abs, y[i] = fabs(a[i]));
DEFINE_VSL_UNARY_FUNC(Tanh, y[i] = tanh(a[i]));

// A simple way to define the vsl unary functions with singular parameter b.
// The operation should be in the form e.g. y[i] = pow(a[i], b)
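
For context on the new DEFINE_VSL_UNARY_FUNC(Tanh, ...) line: when MKL is not available, this macro presumably generates vs/vd-prefixed free functions that apply the given element-wise expression in a plain loop, so math::Tanh still has a vsTanh / vdTanh to delegate to. A rough sketch of the generated float version (an assumption based on the macro tail visible above, not the literal expansion):

// Hypothetical expansion of DEFINE_VSL_UNARY_FUNC(Tanh, y[i] = tanh(a[i]))
// for float; the real macro is defined earlier in mkl_alternate.h.
inline void vsTanh(const int n, const float* a, float* y) {
  for (int i = 0; i < n; ++i) {
    y[i] = tanh(a[i]);  // uses the branch-on-sign helper defined above
  }
}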
6 changes: 6 additions & 0 deletions pycaffe2/core_gradients.py
@@ -26,6 +26,12 @@ def AddSoftmaxGradient(op):
        [op.output[0], GetGradientName(op.output[0])],
        [GetGradientName(op.input[0])])

@GradientRegistry.RegisterGradient("Tanh")
def AddSoftmaxGradient(op):
Review comment (Contributor): AddSoftmaxGradient -> AddTanhGradient

    return CreateOperator('TanhGradient')(
        [op.output[0], GetGradientName(op.output[0])],
        [GetGradientName(op.input[0])])

@GradientRegistry.RegisterGradient("Flatten")
def AddFlattenGradient(op):
    return CreateOperator('ReshapeLike')(