Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTen]Remove infershape of Reshape OP #39631

Merged
merged 9 commits into from
Feb 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 80 additions & 14 deletions paddle/fluid/framework/infershape_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/compat/op_utils.h"
Expand Down Expand Up @@ -54,7 +55,12 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
}

size_t InputSize(const std::string& name) const override {
return ctx_.Inputs(name).size();
if (ctx_.HasInputs(name)) {
return ctx_.Inputs(name).size();
} else if (ctx_.HasInput(name)) {
return 1;
}
return 0;
}

size_t OutputSize(const std::string& name) const override {
Expand Down Expand Up @@ -288,6 +294,16 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
auto& attr_names = std::get<1>(signature.args);
auto& output_names = std::get<2>(signature.args);

auto kernels_map =
phi::KernelFactory::Instance().SelectKernelMap(signature.name);
if (kernels_map.size() == 0) {
PADDLE_THROW(
platform::errors::Unimplemented("Not find `%s` kernels when construct "
"InferMetaContext.",
signature.name));
}
auto attr_defs = kernels_map.cbegin()->second.args_def().attribute_defs();

// TODO(chenweihang): support multiple inputs and outputs later
phi::InferMetaContext infer_mete_context;
for (auto& in_name : input_names) {
Expand All @@ -299,9 +315,70 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
}

for (auto& out_name : output_names) {
if (ctx->HasOutput(out_name)) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
} else {
infer_meta_context.EmplaceBackOutput({nullptr});
}
}
auto attr_reader = ctx->Attrs();
for (auto& attr_name : attr_names) {
if (ctx->HasAttr(attr_name)) {
for (size_t i = 0; i < attr_names.size(); ++i) {
auto attr_name = attr_names[i];
if (attr_defs[i].type_index == std::type_index(typeid(phi::ScalarArray))) {
// When attr is a vector_tensor or tensor, transform it to ScalarArray
if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
if (ctx->IsRuntime()) {
// If is in runtime, we will get tensor's value for ScalarArray
// and push it into attrs
std::vector<Variable*> vars;
vars.reserve(infershape_inputs.size());
for (size_t i = 0; i < infershape_inputs.size(); i++) {
vars.push_back(BOOST_GET_CONST(Variable*, infershape_inputs[i]));
}
if (infershape_inputs.size() != 1) {
infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePtenScalarArrayFromVarList(vars)));
} else {
infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePtenScalarArrayFromVar(*vars[0])));
}
} else {
// If is not in runtime, we will set default value(-1) for ScalarArray
int64_t num_ele = 1;
std::vector<VarDesc*> vars;
vars.reserve(infershape_inputs.size());
for (size_t i = 0; i < infershape_inputs.size(); i++) {
vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i]));
}
for (auto& var : vars) {
const auto& tensor_dims = var->GetShape();
for (size_t i = 0; i < tensor_dims.size(); ++i) {
num_ele *= tensor_dims[i];
}
}
phi::ScalarArray tensor_attr(std::vector<int32_t>(num_ele, -1));
tensor_attr.SetFromTensor(true);
infer_meta_context.EmplaceBackAttr(std::move(tensor_attr));
}
} else if (ctx->HasAttr(attr_name)) {
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
infer_meta_context.EmplaceBackAttr(std::move(
phi::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to ScalarArray when "
"construct KernelContext.",
attr_name));
}
}

} else if (ctx->HasAttr(attr_name)) {
// Emplace Back Attr according to the type of attr.
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(bool))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
Expand Down Expand Up @@ -345,17 +422,6 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
"Unsupported attribute type is received when call "
"InferShapeFunctor."));
}
} else {
// do nothing
}
}

for (auto& out_name : output_names) {
if (ctx->HasOutput(out_name)) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
} else {
infer_meta_context.EmplaceBackOutput({nullptr});
}
}

Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/framework/infershape_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/kernel_registry.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -93,6 +96,17 @@ phi::KernelSignature InferShapeUtilsTestOpArgumentMapping(
{});
}

template <typename T, typename Context>
void InferShapeUtilsTestKernel(
const Context& dev_ctx, const phi::DenseTensor& x, bool attr1, int attr2,
int64_t attr3, float attr4, const std::string& attr5,
const std::vector<bool>& attr6, const std::vector<int>& attr7,
const std::vector<int64_t>& attr8, const std::vector<float>& attr9,
const std::vector<double>& attr10, const std::vector<std::string>& attr11,
phi::DenseTensor* out) {
VLOG(6) << "Come into InferShapeUtilsTestKernel";
}

} // namespace framework
} // namespace paddle

Expand All @@ -104,6 +118,9 @@ REGISTER_OPERATOR(infer_shape_utils_test,
paddle::framework::InferShapeUtilsTestOpMaker,
InferShapeUtilsTestInferShapeFunctor);

PT_REGISTER_KERNEL(infer_shape_utils_test, CPU, ALL_LAYOUT,
paddle::framework::InferShapeUtilsTestKernel, int) {}

TEST(InferShapeUtilsTest, ALL) {
paddle::framework::ProgramDesc prog;
paddle::framework::proto::BlockDesc proto_block;
Expand Down
26 changes: 9 additions & 17 deletions paddle/fluid/operators/reshape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@ limitations under the License. */

#include <string>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/pten_utils.h"

// only can include the headers in paddle/phi/api dirs
#include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/reshape_grad_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"

namespace paddle {
namespace framework {
class InferShapeContext;
Expand Down Expand Up @@ -472,22 +476,6 @@ class Reshape2Op : public ReshapeOp {
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ReshapeOp(type, inputs, outputs, attrs) {}

void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
platform::errors::InvalidArgument(
"Output(XShape) of ReshapeOp should not be null."));
const auto &x_dims = ctx->GetInputDim("X");
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
ctx->SetOutputDim("XShape", phi::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");

ReshapeOp::InferShape(ctx);
}
};

class Reshape2OpMaker : public ReshapeOpMaker {
Expand Down Expand Up @@ -647,10 +635,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);

DELCARE_INFER_SHAPE_FUNCTOR(reshape2, ReshapeInferShapeFunctor,
PT_INFER_META(phi::ReshapeWithXShapeInferMeta));

REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
ops::Reshape2GradMaker<paddle::framework::OpDesc>,
ops::Reshape2GradMaker<paddle::imperative::OpBase>,
ops::ReshapeOpInplaceInferer);
ReshapeInferShapeFunctor, ops::ReshapeOpInplaceInferer);
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>,
ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/api/lib/utils/tensor_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ phi::ScalarArray MakePtenScalarArrayFromVarList(
}

phi::ScalarArray result{vector_data};
result.setInitByTensor(true);
result.SetFromTensor(true);

return result;
}
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/common/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace experimental {
template <typename T>
class ScalarBase {
public:
bool IsInitByTensor() const { return is_init_by_tensor_; }
bool FromTensor() const { return is_from_tensor_; }
// Constructor support implicit
ScalarBase(double val) : dtype_(DataType::FLOAT64) { // NOLINT
data_.f64 = val;
Expand Down Expand Up @@ -104,7 +104,7 @@ class ScalarBase {

// The Tensor must have one dim
ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT
is_init_by_tensor_ = true;
is_from_tensor_ = true;
PD_CHECK(
tensor.numel() == 1,
"The Scalar only supports Tensor with 1 element, but now Tensor has `",
Expand Down Expand Up @@ -196,7 +196,7 @@ class ScalarBase {
friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst);

private:
bool is_init_by_tensor_{false};
bool is_from_tensor_{false};
DataType dtype_;
union data {
bool b;
Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/common/scalar_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ class ScalarArrayBase {
AssignData(date_value, n);
}

bool IsInitByTensor() const { return is_init_by_tensor_; }
bool FromTensor() const { return is_from_tensor_; }

void setInitByTensor(bool val) { is_init_by_tensor_ = val; }
void SetFromTensor(bool val) { is_from_tensor_ = val; }

// The Tensor must have one dim
ScalarArrayBase(const T& tensor) { // NOLINT
is_init_by_tensor_ = true;
is_from_tensor_ = true;
size_t n = tensor.numel();
array_.reserve(n);
switch (tensor.dtype()) {
Expand All @@ -71,7 +71,7 @@ class ScalarArrayBase {

// The Tensor in vec must have only one element
ScalarArrayBase(const std::vector<T>& tensor_list) { // NOLINT
is_init_by_tensor_ = true;
is_from_tensor_ = true;

for (size_t i = 0; i < tensor_list.size(); ++i) {
DataType data_type = tensor_list[i].dtype();
Expand Down Expand Up @@ -117,7 +117,7 @@ class ScalarArrayBase {
// TODO(zhangyunfei) Replace std::vector with a more efficient container
// structure.
std::vector<int64_t> array_;
bool is_init_by_tensor_{false};
bool is_from_tensor_{false};
};

using ScalarArray =
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/core/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::string&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<bool>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<float>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<double>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<std::string>&);

/* Output Helpers */

Expand Down
42 changes: 38 additions & 4 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ limitations under the License. */
#include "paddle/phi/infermeta/unary.h"

#include <set>

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/infermeta_utils.h"

namespace phi {
Expand Down Expand Up @@ -210,7 +210,7 @@ void InferMetaFromVecValue(const MetaTensor& x,
MetaTensor* out) {
PADDLE_ENFORCE_EQ(!shape.empty(),
true,
paddle::platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The parameter 'shape' in ReshapeOp must be set. "
"But received 'shape' is empty."));
auto x_dims = x.dims();
Expand All @@ -227,8 +227,42 @@ void InferMetaFromVecValue(const MetaTensor& x,

void ReshapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
MetaTensor* out) {
InferMetaFromVecValue(x, shape.GetData(), out);
MetaTensor* out,
MetaConfig config) {
auto& shape_data = shape.GetData();
PADDLE_ENFORCE_NOT_NULL(out,
phi::errors::InvalidArgument(
"Output(Out) of ReshapeOp should not be null."));
if (!config.is_runtime && shape.FromTensor()) {
out->set_dims(phi::make_ddim(shape_data));
out->share_lod(x);
return;
}
PADDLE_ENFORCE_GT(shape_data.size(),
0,
phi::errors::InvalidArgument(
"The shape's size in ReshapeOp can't be zero."));
InferMetaFromVecValue(x, shape_data, out);
}

void ReshapeWithXShapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
MetaTensor* xshape,
MetaTensor* out,
MetaConfig config) {
PADDLE_ENFORCE_NOT_NULL(
xshape,
phi::errors::InvalidArgument(
"Output(XShape) of ReshapeOp should not be null."));
const auto& x_dims = x.dims();
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
ReshapeInferMeta(x, shape, out, config);
}

/* Why not use ReduceInferMeta directly?
Expand Down
9 changes: 8 additions & 1 deletion paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,14 @@ void InferMetaFromVecValue(const MetaTensor& x,

void ReshapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
MetaTensor* out);
MetaTensor* out,
MetaConfig config = MetaConfig());

void ReshapeWithXShapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
MetaTensor* xshape,
MetaTensor* out,
MetaConfig config = MetaConfig());

void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis,
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/split_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void SplitKernel(const Context& dev_ctx,
const Scalar& axis_scalar,
std::vector<DenseTensor*> outs) {
// need to infershape output
if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) {
if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) {
std::vector<MetaTensor> out_metas;
for (size_t i = 0; i < outs.size(); ++i) {
out_metas.push_back(outs[i]);
Expand Down
Loading