Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTen] Move grad GetExpectedPtenKernelArgs into pten #39418

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1171,8 +1171,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
std::string pt_kernel_name;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) {
if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) {
pt_kernel_signature_.reset(new KernelSignature(
std::move(this->GetExpectedPtenKernelArgs(exe_ctx))));
pt_kernel_signature_.reset(
new KernelSignature(std::move(GetExpectedPtenKernelArgs(exe_ctx))));
VLOG(6) << *pt_kernel_signature_.get();

kernel_type_.reset(
Expand Down Expand Up @@ -1359,7 +1359,7 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
pten::KernelKey OperatorWithKernel::ChoosePtenKernel(
const ExecutionContext& ctx) const {
pt_kernel_signature_.reset(
new KernelSignature(std::move(this->GetExpectedPtenKernelArgs(ctx))));
new KernelSignature(std::move(GetExpectedPtenKernelArgs(ctx))));
VLOG(6) << *pt_kernel_signature_.get();

kernel_type_.reset(
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ class OperatorWithKernel : public OperatorBase {
* When selecting Kernel during Op execution, select the arguments of the
* original Op according to the GetExpectedPtenKernelArgs returned arguments.
*/
virtual pten::KernelSignature GetExpectedPtenKernelArgs(
pten::KernelSignature GetExpectedPtenKernelArgs(
const ExecutionContext& ctx) const;

/* member functions for adapting to pten lib */
Expand Down
7 changes: 0 additions & 7 deletions paddle/fluid/operators/digamma_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,6 @@ class DigammaGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
}

framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
return framework::KernelSignature("digamma_grad",
{framework::GradVarName("Out"), "X"}, {},
{framework::GradVarName("X")});
}
};

template <typename T>
Expand Down
7 changes: 0 additions & 7 deletions paddle/fluid/operators/dot_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,6 @@ class DotGradOp : public framework::OperatorWithKernel {
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}

framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
return framework::KernelSignature(
"dot_grad", {"X", "Y", framework::GradVarName("Out")}, {},
{framework::GradVarName("X"), framework::GradVarName("Y")});
}
};

template <typename T>
Expand Down
12 changes: 0 additions & 12 deletions paddle/fluid/operators/elementwise/elementwise_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -353,18 +353,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
tensor.place(), tensor.layout());
}
}
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
if (Type() == "elementwise_add_grad") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
return framework::KernelSignature(
"add_grad", {"X", "Y", framework::GradVarName("Out")}, {"axis"},
{framework::GradVarName("X"), framework::GradVarName("Y")});
}
}

return framework::KernelSignature("None", {"X"}, {}, {"Out"});
}
};

class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
Expand Down
6 changes: 0 additions & 6 deletions paddle/fluid/operators/flatten_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,6 @@ class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel {
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
return framework::KernelSignature("flatten_grad",
{framework::GradVarName("Out"), "XShape"},
{}, {framework::GradVarName("X")});
}
};
DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceInferer,
Expand Down
24 changes: 0 additions & 24 deletions paddle/fluid/operators/matmul_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,14 +389,6 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
tensor.place(), tensor.layout());
}
}

framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
return framework::KernelSignature(
"matmul_grad", {"X", "Y", framework::GradVarName("Out")},
{"trans_x", "trans_y"},
{framework::GradVarName("X"), framework::GradVarName("Y")});
}
};

template <typename T>
Expand Down Expand Up @@ -439,13 +431,6 @@ class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel {
context->ShareDim("DOut", "DDOut");
}
}

framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
return framework::KernelSignature(
"matmul_double_grad", {"X", "Y", "DOut", "DDX", "DDY"},
{"trans_x", "trans_y"}, {"DX", "DY", "DDOut"});
}
};

template <typename T>
Expand Down Expand Up @@ -515,15 +500,6 @@ class MatMulV2OpTripleGrad : public framework::OperatorWithKernel {
context->ShareDim("Y", "D_DDY_out");
}
}

framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
return framework::KernelSignature(
"matmul_triple_grad",
{"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
{"trans_x", "trans_y"},
{"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"});
}
};

template <typename T>
Expand Down
28 changes: 0 additions & 28 deletions paddle/fluid/operators/reduce_ops/reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -547,34 +547,6 @@ class ReduceOp : public framework::OperatorWithKernel {
}
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}

framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
bool reduce_all = ctx.Attr<bool>("reduce_all");
if (Type() == "reduce_sum") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (!reduce_all) {
return framework::KernelSignature(
"sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"});
}
return framework::KernelSignature(
"sum_raw", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"},
{"Out"});
}
}
if (Type() == "reduce_mean") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (!reduce_all) {
return framework::KernelSignature("mean", {"X"}, {"dim", "keep_dim"},
{"Out"});
}
return framework::KernelSignature(
"mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
}
}
// TODO(chentianyu03): support other cases after selected rows added
return framework::KernelSignature("reduce.unregistered", {}, {}, {});
}
};

class ReduceOpUseInputPlace : public ReduceOp {
Expand Down
12 changes: 0 additions & 12 deletions paddle/fluid/operators/reshape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -579,13 +579,6 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}

framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
return framework::KernelSignature("reshape_grad",
{framework::GradVarName("Out")}, {},
{framework::GradVarName("X")});
}
};

class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
Expand Down Expand Up @@ -622,11 +615,6 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
return framework::KernelSignature("reshape_double_grad", {"DDX"}, {},
{"DDOut"});
}
};

DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"});
Expand Down
27 changes: 27 additions & 0 deletions paddle/pten/ops/compat/digamma_sig.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/core/compat/op_utils.h"

namespace pten {

KernelSignature DigammaGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"digamma_grad", {GradVarName("Out"), "X"}, {}, {GradVarName("X")});
}

} // namespace pten

PT_REGISTER_ARG_MAPPING_FN(digamma_grad, pten::DigammaGradOpArgumentMapping);
28 changes: 28 additions & 0 deletions paddle/pten/ops/compat/dot_sig.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/core/compat/op_utils.h"

namespace pten {

KernelSignature DotGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("dot_grad",
{"X", "Y", GradVarName("Out")},
{},
{GradVarName("X"), GradVarName("Y")});
}

} // namespace pten

PT_REGISTER_ARG_MAPPING_FN(dot_grad, pten::DotGradOpArgumentMapping);
14 changes: 13 additions & 1 deletion paddle/pten/ops/compat/elementwise_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,24 @@ KernelSignature ElementwiseDivOpArgumentMapping(
return KernelSignature("unregistered", {}, {}, {});
}

KernelSignature ElementwiseAddGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("X")) {
return KernelSignature("add_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
}
return KernelSignature("unregistered", {}, {}, {});
}

} // namespace pten

PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);

PT_REGISTER_ARG_MAPPING_FN(elementwise_add,
pten::ElementwiseAddOpArgumentMapping);
Expand All @@ -81,3 +91,5 @@ PT_REGISTER_ARG_MAPPING_FN(elementwise_mul,
pten::ElementwiseMulOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(elementwise_div,
pten::ElementwiseDivOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(elementwise_add_grad,
pten::ElementwiseAddGradOpArgumentMapping);
8 changes: 8 additions & 0 deletions paddle/pten/ops/compat/flatten_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,18 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
}

KernelSignature FlattenGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"flatten_grad", {GradVarName("Out"), "XShape"}, {}, {GradVarName("X")});
}

} // namespace pten

PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range, flatten);
PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range_grad, flatten_grad);

PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range,
pten::FlattenOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range_grad,
pten::FlattenGradOpArgumentMapping);
34 changes: 33 additions & 1 deletion paddle/pten/ops/compat/matmul_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,41 @@ limitations under the License. */

#include "paddle/pten/core/compat/op_utils.h"

namespace pten {} // namespace pten
namespace pten {

KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("matmul_grad",
{"X", "Y", GradVarName("Out")},
{"trans_x", "trans_y"},
{GradVarName("X"), GradVarName("Y")});
}

KernelSignature MatmulDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("matmul_double_grad",
{"X", "Y", "DOut", "DDX", "DDY"},
{"trans_x", "trans_y"},
{"DX", "DY", "DDOut"});
}

KernelSignature MatmulTripleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"matmul_triple_grad",
{"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
{"trans_x", "trans_y"},
{"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"});
}

} // namespace pten

PT_REGISTER_BASE_KERNEL_NAME(matmul_v2, matmul);
PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad, matmul_grad);
PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad_grad, matmul_double_grad);
PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_triple_grad, matmul_triple_grad);

PT_REGISTER_ARG_MAPPING_FN(matmul_v2_grad, pten::MatmulGradOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(matmul_v2_grad_grad,
pten::MatmulDoubleGradOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(matmul_v2_triple_grad,
pten::MatmulTripleGradOpArgumentMapping);
2 changes: 1 addition & 1 deletion paddle/pten/ops/compat/reduce_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("X")) {
if (!reduce_all) {
return KernelSignature(
"sum", {"X"}, {"dim", "keep_dim", "out_dtype"}, {"Out"});
"sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"});
}
return KernelSignature("sum_raw",
{"X"},
Expand Down
14 changes: 14 additions & 0 deletions paddle/pten/ops/compat/reshape_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,24 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
}

KernelSignature ReshapeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"reshape_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
}

KernelSignature ReshapeDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("reshape_double_grad", {"DDX"}, {}, {"DDOut"});
}

} // namespace pten

PT_REGISTER_BASE_KERNEL_NAME(reshape2, reshape);
PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad, reshape_grad);
PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad_grad, reshape_double_grad);

PT_REGISTER_ARG_MAPPING_FN(reshape2, pten::ReshapeOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(reshape2_grad, pten::ReshapeGradOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(reshape2_grad_grad,
pten::ReshapeDoubleGradOpArgumentMapping);