[PTen] Move grad GetExpectedPtenKernelArgs into pten (#39418)
* Move grad GetExpectedPtenKernelArgs into pten

* Fix reduce_sum error

* Fix elementwise_sub_grad failure

* Revert kernel judgment change
chenwhql authored Feb 11, 2022
1 parent 22c67d1 commit 667bd96
Showing 16 changed files with 128 additions and 103 deletions.
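Every file below follows the same pattern: the grad op's virtual GetExpectedPtenKernelArgs override in paddle/fluid/operators is deleted, and an equivalent standalone argument-mapping function is registered in paddle/pten/ops/compat. A minimal sketch of that pattern, assuming a hypothetical example_grad op (only KernelSignature, ArgumentMappingContext, GradVarName, and PT_REGISTER_ARG_MAPPING_FN are taken from the diffs; the op itself is invented for illustration):

#include "paddle/pten/core/compat/op_utils.h"

namespace pten {

// Maps the fluid op's input/attr/output names onto the pten kernel's
// expected arguments; the framework consults this instead of a per-op
// virtual override.
KernelSignature ExampleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("example_grad",
                         {"X", GradVarName("Out")},  // kernel inputs
                         {},                         // kernel attributes
                         {GradVarName("X")});        // kernel outputs
}

}  // namespace pten

// Registers the mapping under the fluid op name at static-initialization
// time, replacing the deleted GetExpectedPtenKernelArgs override.
PT_REGISTER_ARG_MAPPING_FN(example_grad, pten::ExampleGradOpArgumentMapping);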
6 changes: 3 additions & 3 deletions paddle/fluid/framework/operator.cc
@@ -1171,8 +1171,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   std::string pt_kernel_name;
   if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) {
     if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) {
-      pt_kernel_signature_.reset(new KernelSignature(
-          std::move(this->GetExpectedPtenKernelArgs(exe_ctx))));
+      pt_kernel_signature_.reset(
+          new KernelSignature(std::move(GetExpectedPtenKernelArgs(exe_ctx))));
       VLOG(6) << *pt_kernel_signature_.get();

       kernel_type_.reset(
@@ -1359,7 +1359,7 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
 pten::KernelKey OperatorWithKernel::ChoosePtenKernel(
     const ExecutionContext& ctx) const {
   pt_kernel_signature_.reset(
-      new KernelSignature(std::move(this->GetExpectedPtenKernelArgs(ctx))));
+      new KernelSignature(std::move(GetExpectedPtenKernelArgs(ctx))));
   VLOG(6) << *pt_kernel_signature_.get();

   kernel_type_.reset(
2 changes: 1 addition & 1 deletion paddle/fluid/framework/operator.h
@@ -606,7 +606,7 @@ class OperatorWithKernel : public OperatorBase {
    * When selecting Kernel during Op execution, select the arguments of the
    * original Op according to the GetExpectedPtenKernelArgs returned arguments.
    */
-  virtual pten::KernelSignature GetExpectedPtenKernelArgs(
+  pten::KernelSignature GetExpectedPtenKernelArgs(
       const ExecutionContext& ctx) const;

   /* member functions for adapting to pten lib */
7 changes: 0 additions & 7 deletions paddle/fluid/operators/digamma_op.cc
@@ -64,13 +64,6 @@ class DigammaGradOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
     ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::KernelSignature("digamma_grad",
-                                      {framework::GradVarName("Out"), "X"}, {},
-                                      {framework::GradVarName("X")});
-  }
 };

 template <typename T>
7 changes: 0 additions & 7 deletions paddle/fluid/operators/dot_op.cc
@@ -117,13 +117,6 @@ class DotGradOp : public framework::OperatorWithKernel {
                            ctx, framework::GradVarName("Out")),
         ctx.GetPlace());
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::KernelSignature(
-        "dot_grad", {"X", "Y", framework::GradVarName("Out")}, {},
-        {framework::GradVarName("X"), framework::GradVarName("Y")});
-  }
 };

 template <typename T>
12 changes: 0 additions & 12 deletions paddle/fluid/operators/elementwise/elementwise_op.h
@@ -353,18 +353,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
                                    tensor.place(), tensor.layout());
     }
   }
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    if (Type() == "elementwise_add_grad") {
-      if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        return framework::KernelSignature(
-            "add_grad", {"X", "Y", framework::GradVarName("Out")}, {"axis"},
-            {framework::GradVarName("X"), framework::GradVarName("Y")});
-      }
-    }
-
-    return framework::KernelSignature("None", {"X"}, {}, {"Out"});
-  }
 };

 class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
6 changes: 0 additions & 6 deletions paddle/fluid/operators/flatten_op.cc
@@ -421,12 +421,6 @@ class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel {
                            ctx, framework::GradVarName("Out")),
         ctx.device_context());
   }
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::KernelSignature("flatten_grad",
-                                      {framework::GradVarName("Out"), "XShape"},
-                                      {}, {framework::GradVarName("X")});
-  }
 };
 DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"});
 DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceInferer,
24 changes: 0 additions & 24 deletions paddle/fluid/operators/matmul_v2_op.cc
@@ -389,14 +389,6 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
                                    tensor.place(), tensor.layout());
     }
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::KernelSignature(
-        "matmul_grad", {"X", "Y", framework::GradVarName("Out")},
-        {"trans_x", "trans_y"},
-        {framework::GradVarName("X"), framework::GradVarName("Y")});
-  }
 };

 template <typename T>
@@ -439,13 +431,6 @@ class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel {
       context->ShareDim("DOut", "DDOut");
     }
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::KernelSignature(
-        "matmul_double_grad", {"X", "Y", "DOut", "DDX", "DDY"},
-        {"trans_x", "trans_y"}, {"DX", "DY", "DDOut"});
-  }
 };

 template <typename T>
@@ -515,15 +500,6 @@ class MatMulV2OpTripleGrad : public framework::OperatorWithKernel {
       context->ShareDim("Y", "D_DDY_out");
     }
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::KernelSignature(
-        "matmul_triple_grad",
-        {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
-        {"trans_x", "trans_y"},
-        {"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"});
-  }
 };

 template <typename T>
28 changes: 0 additions & 28 deletions paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -547,34 +547,6 @@ class ReduceOp : public framework::OperatorWithKernel {
     }
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    bool reduce_all = ctx.Attr<bool>("reduce_all");
-    if (Type() == "reduce_sum") {
-      if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        if (!reduce_all) {
-          return framework::KernelSignature(
-              "sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"});
-        }
-        return framework::KernelSignature(
-            "sum_raw", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"},
-            {"Out"});
-      }
-    }
-    if (Type() == "reduce_mean") {
-      if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        if (!reduce_all) {
-          return framework::KernelSignature("mean", {"X"}, {"dim", "keep_dim"},
-                                            {"Out"});
-        }
-        return framework::KernelSignature(
-            "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
-      }
-    }
-    // TODO(chentianyu03): support other cases after selected rows added
-    return framework::KernelSignature("reduce.unregistered", {}, {}, {});
-  }
 };

 class ReduceOpUseInputPlace : public ReduceOp {
12 changes: 0 additions & 12 deletions paddle/fluid/operators/reshape_op.cc
@@ -579,13 +579,6 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
     return framework::OpKernelType(expected_kernel_type.data_type_,
                                    tensor.place(), tensor.layout());
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::KernelSignature("reshape_grad",
-                                      {framework::GradVarName("Out")}, {},
-                                      {framework::GradVarName("X")});
-  }
 };

 class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
@@ -622,11 +615,6 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
     return framework::OpKernelType(expected_kernel_type.data_type_,
                                    tensor.place(), tensor.layout());
   }
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::KernelSignature("reshape_double_grad", {"DDX"}, {},
-                                      {"DDOut"});
-  }
 };

 DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"});
27 changes: 27 additions & 0 deletions paddle/pten/ops/compat/digamma_sig.cc
@@ -0,0 +1,27 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/pten/core/compat/op_utils.h"
+
+namespace pten {
+
+KernelSignature DigammaGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "digamma_grad", {GradVarName("Out"), "X"}, {}, {GradVarName("X")});
+}
+
+}  // namespace pten
+
+PT_REGISTER_ARG_MAPPING_FN(digamma_grad, pten::DigammaGradOpArgumentMapping);
28 changes: 28 additions & 0 deletions paddle/pten/ops/compat/dot_sig.cc
@@ -0,0 +1,28 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/pten/core/compat/op_utils.h"
+
+namespace pten {
+
+KernelSignature DotGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("dot_grad",
+                         {"X", "Y", GradVarName("Out")},
+                         {},
+                         {GradVarName("X"), GradVarName("Y")});
+}
+
+}  // namespace pten
+
+PT_REGISTER_ARG_MAPPING_FN(dot_grad, pten::DotGradOpArgumentMapping);
14 changes: 13 additions & 1 deletion paddle/pten/ops/compat/elementwise_sig.cc
@@ -64,14 +64,24 @@ KernelSignature ElementwiseDivOpArgumentMapping(
   return KernelSignature("unregistered", {}, {}, {});
 }

+KernelSignature ElementwiseAddGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  if (ctx.IsDenseTensorInput("X")) {
+    return KernelSignature("add_grad",
+                           {"X", "Y", GradVarName("Out")},
+                           {"axis"},
+                           {GradVarName("X"), GradVarName("Y")});
+  }
+  return KernelSignature("unregistered", {}, {}, {});
+}
+
 }  // namespace pten

 PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
 PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
 PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
 PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
+PT_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
+PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);

 PT_REGISTER_ARG_MAPPING_FN(elementwise_add,
                            pten::ElementwiseAddOpArgumentMapping);
@@ -81,3 +91,5 @@ PT_REGISTER_ARG_MAPPING_FN(elementwise_mul,
                            pten::ElementwiseMulOpArgumentMapping);
 PT_REGISTER_ARG_MAPPING_FN(elementwise_div,
                            pten::ElementwiseDivOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(elementwise_add_grad,
+                           pten::ElementwiseAddGradOpArgumentMapping);
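A note on the mapping above, for readers new to the compat layer: GradVarName appends fluid's "@GRAD" suffix to a variable name, and "unregistered" is the sentinel signature returned when no pten kernel should be selected (here, when "X" is not a dense tensor), in which case the fluid kernel path is expected to be used instead. Spelled out by hand for illustration, the dense-tensor branch yields:

// What ElementwiseAddGradOpArgumentMapping returns when "X" is a dense
// tensor (GradVarName("Out") expands to "Out@GRAD", and so on):
//   kernel name: "add_grad"
//   inputs:      {"X", "Y", "Out@GRAD"}
//   attributes:  {"axis"}
//   outputs:     {"X@GRAD", "Y@GRAD"}
// For any other input type the sentinel
// KernelSignature("unregistered", {}, {}, {}) is returned.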
8 changes: 8 additions & 0 deletions paddle/pten/ops/compat/flatten_sig.cc
@@ -28,10 +28,18 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
   }
 }

+KernelSignature FlattenGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "flatten_grad", {GradVarName("Out"), "XShape"}, {}, {GradVarName("X")});
+}
+
 }  // namespace pten

 PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range, flatten);
 PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range_grad, flatten_grad);

 PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range,
                            pten::FlattenOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range_grad,
+                           pten::FlattenGradOpArgumentMapping);
34 changes: 33 additions & 1 deletion paddle/pten/ops/compat/matmul_sig.cc
@@ -14,9 +14,41 @@ limitations under the License. */

 #include "paddle/pten/core/compat/op_utils.h"

-namespace pten {}  // namespace pten
+namespace pten {
+
+KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("matmul_grad",
+                         {"X", "Y", GradVarName("Out")},
+                         {"trans_x", "trans_y"},
+                         {GradVarName("X"), GradVarName("Y")});
+}
+
+KernelSignature MatmulDoubleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("matmul_double_grad",
+                         {"X", "Y", "DOut", "DDX", "DDY"},
+                         {"trans_x", "trans_y"},
+                         {"DX", "DY", "DDOut"});
+}
+
+KernelSignature MatmulTripleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "matmul_triple_grad",
+      {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
+      {"trans_x", "trans_y"},
+      {"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"});
+}
+
+}  // namespace pten

 PT_REGISTER_BASE_KERNEL_NAME(matmul_v2, matmul);
 PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad, matmul_grad);
 PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad_grad, matmul_double_grad);
+PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_triple_grad, matmul_triple_grad);

+PT_REGISTER_ARG_MAPPING_FN(matmul_v2_grad, pten::MatmulGradOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(matmul_v2_grad_grad,
+                           pten::MatmulDoubleGradOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(matmul_v2_triple_grad,
+                           pten::MatmulTripleGradOpArgumentMapping);
2 changes: 1 addition & 1 deletion paddle/pten/ops/compat/reduce_sig.cc
@@ -21,7 +21,7 @@ KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) {
   if (ctx.IsDenseTensorInput("X")) {
     if (!reduce_all) {
       return KernelSignature(
-          "sum", {"X"}, {"dim", "keep_dim", "out_dtype"}, {"Out"});
+          "sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"});
     }
     return KernelSignature("sum_raw",
                            {"X"},
14 changes: 14 additions & 0 deletions paddle/pten/ops/compat/reshape_sig.cc
@@ -26,10 +26,24 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
   }
 }

+KernelSignature ReshapeGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "reshape_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
+}
+
+KernelSignature ReshapeDoubleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("reshape_double_grad", {"DDX"}, {}, {"DDOut"});
+}
+
 }  // namespace pten

 PT_REGISTER_BASE_KERNEL_NAME(reshape2, reshape);
 PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad, reshape_grad);
 PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad_grad, reshape_double_grad);

 PT_REGISTER_ARG_MAPPING_FN(reshape2, pten::ReshapeOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(reshape2_grad, pten::ReshapeGradOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(reshape2_grad_grad,
+                           pten::ReshapeDoubleGradOpArgumentMapping);
