This repository has been archived by the owner on Feb 7, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.9k
add tanh_op #14
Closed
Closed
add tanh_op #14
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#include "caffe2/operators/tanh_op.h" | ||
|
||
namespace caffe2 { | ||
|
||
float tanh(float x) { | ||
if (x >= 0) { | ||
float enx = exp(-2.0*x); | ||
return (1 - enx)/(1 + enx); | ||
} else { | ||
float epx = exp(2.0*x); | ||
return (epx - 1)/(epx + 1); | ||
} | ||
} | ||
|
||
template <> | ||
bool TanhOp<float, CPUContext>::RunOnDevice() { | ||
auto& X = Input(0); | ||
auto* Y = Output(0); | ||
DCHECK_GT(X.size(), 0); | ||
Y->ReshapeLike(X); | ||
const float* Xdata = X.data(); | ||
float* Ydata = Y->mutable_data(); | ||
for (int i = 0; i < X.size(); ++i) { | ||
Ydata[i] = tanh(Xdata[i]); | ||
} | ||
return true; | ||
} | ||
|
||
template <> | ||
bool TanhGradientOp<float, CPUContext>::RunOnDevice() { | ||
auto& Y = Input(0); | ||
auto& dY = Input(1); | ||
auto* dX = Output(0); | ||
DCHECK_GT(Y.size(), 0); | ||
DCHECK_EQ(dY.size(), Y.size()); | ||
dX->ReshapeLike(Y); | ||
const float* Ydata = Y.data(); | ||
const float* dYdata = dY.data(); | ||
float* dXdata = dX->mutable_data(); | ||
for (int i = 0; i < dX.size(); ++i) { | ||
dXdata[i] = dYdata[i]*(1 - Ydata[i]*Ydata[i]); | ||
} | ||
return true; | ||
} | ||
|
||
namespace { | ||
REGISTER_CPU_OPERATOR(Tanh, TanhOp<float, CPUContext>) | ||
REGISTER_CPU_OPERATOR(TanhGradient, TanhGradientOp<float, CPUContext>) | ||
} // namespace | ||
} // namespace caffe2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
#include "caffe2/core/context_gpu.h" | ||
#include "caffe2/operators/tanh_op.h" | ||
|
||
namespace caffe2 { | ||
namespace { | ||
__device__ float tanh(float x) { | ||
if (x >= 0) { | ||
float enx = exp(-2.0*x); | ||
return (1 - enx)/(1 + enx); | ||
} else { | ||
float epx = exp(2.0*x); | ||
return (epx - 1)/(epx + 1); | ||
} | ||
} | ||
|
||
template <typename dtype> | ||
__global__ void TanhKernel(const int N, const dtype* X, dtype* Y) { | ||
CUDA_1D_KERNEL_LOOP(i, N) { | ||
Y[i] = tanh(X[i]); | ||
} | ||
} | ||
|
||
template <typename dtype> | ||
__global__ void TanhGradientKernel(const int N, const dtype* Y, const dtype* dY, | ||
dtype* dX) { | ||
CUDA_1D_KERNEL_LOOP(i, N) { | ||
dX[i] = dY[i]*(1 - Y[i]*Y[i]); | ||
} | ||
} | ||
} // namespace | ||
|
||
template <> | ||
bool TanhOp<float, CUDAContext>::RunOnDevice() { | ||
auto& X = Input(0); | ||
auto* Y = Output(0); | ||
DCHECK_GT(X.size(), 0); | ||
Y->ReshapeLike(X); | ||
TanhKernel<<<CAFFE_GET_BLOCKS(X.size()), CAFFE_CUDA_NUM_THREADS, | ||
0, device_context_.cuda_stream()>>>( | ||
X.size(), X.data(), Y->mutable_data()); | ||
return true; | ||
} | ||
|
||
template <> | ||
bool TanhGradientOp<float, CUDAContext>::RunOnDevice() { | ||
auto& Y = Input(0); | ||
auto& dY = Input(1); | ||
auto* dX = Output(0); | ||
DCHECK_GT(Y.size(), 0); | ||
DCHECK_EQ(dY.size(), Y.size()); | ||
dX->ReshapeLike(Y); | ||
TanhGradientKernel<<<CAFFE_GET_BLOCKS(Y.size()), CAFFE_CUDA_NUM_THREADS, | ||
0, device_context_.cuda_stream()>>>( | ||
Y.size(), Y.data(), dY.data(), dX->mutable_data()); | ||
return true; | ||
} | ||
|
||
namespace { | ||
REGISTER_CUDA_OPERATOR(Tanh, TanhOp<float, CUDAContext>) | ||
REGISTER_CUDA_OPERATOR(TanhGradient, TanhGradientOp<float, CUDAContext>) | ||
} // namespace | ||
} // namespace caffe2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#ifndef CAFFE2_OPERATORS_TANH_OP_H_ | ||
#define CAFFE2_OPERATORS_TANH_OP_H_ | ||
|
||
#include "caffe2/core/context.h" | ||
#include "caffe2/core/operator.h" | ||
#include "caffe2/utils/math.h" | ||
#include "glog/logging.h" | ||
|
||
namespace caffe2 { | ||
|
||
template <typename dtype, class DeviceContext> | ||
class TanhOp final : public Operator<dtype, DeviceContext> { | ||
public: | ||
USE_SIMPLE_CTOR_DTOR(TanhOp); | ||
USE_OPERATOR_BASE_FUNCTIONS; | ||
|
||
bool RunOnDevice(); | ||
|
||
protected: | ||
INPUT_OUTPUT_STATS(1, 1, 1, 1); | ||
DISABLE_COPY_AND_ASSIGN(TanhOp); | ||
}; | ||
|
||
template <typename dtype, class DeviceContext> | ||
class TanhGradientOp final : public Operator<dtype, DeviceContext> { | ||
public: | ||
USE_SIMPLE_CTOR_DTOR(TanhGradientOp); | ||
USE_OPERATOR_BASE_FUNCTIONS; | ||
|
||
bool RunOnDevice(); | ||
|
||
protected: | ||
// Input: X, dY; Output: dX | ||
INPUT_OUTPUT_STATS(2, 2, 1, 1); | ||
DISABLE_COPY_AND_ASSIGN(TanhGradientOp); | ||
}; | ||
|
||
} // namespace caffe2 | ||
|
||
#endif // CAFFE2_OPERATORS_TANH_OP_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,6 +38,8 @@ DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log, vsLn) | |
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log, vdLn) | ||
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sqr, vsSqr) | ||
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sqr, vdSqr) | ||
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Tanh, vsTanh) | ||
DELEGATE_SIMPLE_UNARY_FUNCTION(double, Tanh, vdTanh) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (not sure) is MKL using vsTanh or vstanh? (i.e. capital or non-capital?) |
||
|
||
template <> | ||
void Powx<float, CPUContext>( | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,12 @@ def AddSoftmaxGradient(op): | |
[op.output[0], GetGradientName(op.output[0])], | ||
[GetGradientName(op.input[0])]) | ||
|
||
@GradientRegistry.RegisterGradient("Tanh") | ||
def AddSoftmaxGradient(op): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AddSoftmaxGradient -> AddTanhGradient |
||
return CreateOperator('TanhGradient')( | ||
[op.output[0], GetGradientName(op.output[0])], | ||
[GetGradientName(op.input[0])]) | ||
|
||
@GradientRegistry.RegisterGradient("Flatten") | ||
def AddFlattenGradient(op): | ||
return CreateOperator('ReshapeLike')( | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Input should be "Y, dY" instead of "X, dY"?