add gradient kernel of det op and slogdet op (#36013)
* add gradient kernel of det op and slogdet op

* fix CI APPROVAL problem
thisjiang committed Sep 24, 2021
1 parent 787273e commit b91e8ee
Showing 4 changed files with 266 additions and 75 deletions.
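For reference (standard matrix calculus, stated here for context rather than quoted from the patch): for an invertible square matrix A, the identities that det/slogdet gradient kernels are built on are

\[
  \frac{\partial \det(A)}{\partial A} = \det(A)\, A^{-\top},
  \qquad
  \frac{\partial \log\lvert\det(A)\rvert}{\partial A} = A^{-\top}.
\]

In reverse mode, an upstream scalar gradient \(\bar{d}\) backpropagates as \(\bar{A} = \bar{d}\,\det(A)\,A^{-\top}\) for det, and as \(\bar{A} = \bar{d}\,A^{-\top}\) for the log-abs-det output of slogdet; the sign output of slogdet is piecewise constant, so it carries no gradient. The same rule applies independently to each matrix in a batched input.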
11 changes: 7 additions & 4 deletions paddle/fluid/operators/determinant_op.cc
@@ -48,6 +48,8 @@ class DeterminantGradOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
                    "DeterminantGradOp");
     OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "DeterminantGradOp");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "DeterminantGradOp");
     OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
                    framework::GradVarName("Input"), "DeterminantGradOp");

@@ -117,7 +119,8 @@ class SlogDeterminantGradOp : public framework::OperatorWithKernel {
                    "SlogDeterminantGradOp");
     OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out",
                    "SlogDeterminantGradOp");
-
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "SlogDeterminantGradOp");
     OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
                    framework::GradVarName("Input"), "SlogDeterminantGradOp");

@@ -179,13 +182,13 @@ REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp,
                   ops::SlogDeterminantGradOpMaker<paddle::imperative::OpBase>);
 
 REGISTER_OPERATOR(slogdeterminant_grad,
-                  ops::DeterminantGradOp)  // reuse det grad op
+                  ops::SlogDeterminantGradOp)  // reuse det grad op
 
 REGISTER_OP_CPU_KERNEL(
     slogdeterminant, ops::SlogDeterminantKernel<plat::CPUDeviceContext, float>,
     ops::SlogDeterminantKernel<plat::CPUDeviceContext, double>);
 
 REGISTER_OP_CPU_KERNEL(
     slogdeterminant_grad,
-    ops::DeterminantGradKernel<plat::CPUDeviceContext, float>,
-    ops::DeterminantGradKernel<plat::CPUDeviceContext, double>);
+    ops::SlogDeterminantGradKernel<plat::CPUDeviceContext, float>,
+    ops::SlogDeterminantGradKernel<plat::CPUDeviceContext, double>);
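With the registration fix above, slogdeterminant_grad now uses its own SlogDeterminantGradOp and SlogDeterminantGradKernel rather than reusing the det versions. Below is a minimal standalone sketch of the rule the log-abs-det path corresponds to, dA = g * inv(A)^T. It uses Eigen, is hypothetical (not code from this patch), and ignores batching and the sign output, which the real kernel must handle.

// slogdet_grad_check.cc, hypothetical standalone example, not part of the patch.
// Verifies the reverse-mode rule behind a slogdet gradient kernel:
// for the log|det(A)| output with upstream gradient g, dA = g * inv(A)^T.
#include <Eigen/Dense>
#include <cmath>
#include <iostream>

int main() {
  Eigen::Matrix2d A;
  A << 4.0, 7.0,
       2.0, 6.0;  // det(A) = 10
  double logabsdet = std::log(std::abs(A.determinant()));
  double g = 1.0;  // upstream gradient d(loss)/d(log|det A|)
  Eigen::Matrix2d dA = g * A.inverse().transpose();
  std::cout << "log|det(A)| = " << logabsdet << "\ndA =\n" << dA << "\n";

  // Finite-difference check on A(0,0): should match dA(0,0) = 0.6.
  const double eps = 1e-6;
  Eigen::Matrix2d Ap = A;
  Ap(0, 0) += eps;
  double fd = (std::log(std::abs(Ap.determinant())) - logabsdet) / eps;
  std::cout << "finite-diff dA(0,0) ~ " << fd << "\n";
  return 0;
}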
36 changes: 0 additions & 36 deletions paddle/fluid/operators/determinant_op.cu
@@ -14,42 +14,6 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/determinant_op.h"
-#include "paddle/fluid/platform/cuda_primitives.h"
-
-namespace paddle {
-namespace operators {
-
-using platform::PADDLE_CUDA_NUM_THREADS;
-using Tensor = framework::Tensor;
-
-template <typename T>
-__global__ void DeterminantGrad(const size_t numel, T* out) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid < numel) {
-    out[tid] = static_cast<T>(1);
-  }
-}
-
-template <typename T>
-class DeterminantGradCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
-    const T* dout_data = dout->data<T>();
-    auto dout_dim = vectorize(dout->dims());
-
-    auto* dx = context.Output<Tensor>(framework::GradVarName("Input"));
-    T* dx_data = dx->mutable_data<T>(context.GetPlace());
-
-    int64_t numel = dx->numel();
-    for (int64_t idx = 0; idx < numel; idx++) {
-      dx_data[idx] = static_cast<T>(1);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
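A note on the deletion above: the removed DeterminantGradCUDAKernel was a placeholder, writing ones into every element of the input gradient (dx_data[idx] = static_cast<T>(1)) instead of the actual derivative. After this commit the .cu file only pulls in determinant_op.h, where the real gradient kernels added by this change presumably live (that part of the diff is not shown here). For contrast, a minimal Eigen sketch of the correct det gradient, hypothetical and not code from the patch:

// det_grad_sketch.cc, hypothetical standalone example, not part of the patch.
// The true reverse-mode rule for det: dA = g * det(A) * inv(A)^T, i.e. g times
// the cofactor matrix of A, unlike the deleted placeholder, which wrote all ones.
#include <Eigen/Dense>
#include <iostream>

int main() {
  Eigen::Matrix2d A;
  A << 4.0, 7.0,
       2.0, 6.0;  // det(A) = 10
  double g = 1.0;  // upstream gradient d(loss)/d(det A)
  Eigen::Matrix2d dA = g * A.determinant() * A.inverse().transpose();
  std::cout << "dA =\n" << dA << "\n";  // expected: [[6, -2], [-7, 4]]
  return 0;
}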
