[PTen]Remove infershape of Reshape OP (#39631)
* remove infershape and XShape

* add XShape

* fix bugs when running CI

* fix bugs when running CI

* fix bugs when running the infrt test

* pass coverage
YuanRisheng authored Feb 21, 2022
1 parent f858b64 commit 45dd4a5
Showing 14 changed files with 181 additions and 60 deletions.
94 changes: 80 additions & 14 deletions paddle/fluid/framework/infershape_utils.cc
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/compat/op_utils.h"
@@ -54,7 +55,12 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
}

size_t InputSize(const std::string& name) const override {
- return ctx_.Inputs(name).size();
if (ctx_.HasInputs(name)) {
return ctx_.Inputs(name).size();
} else if (ctx_.HasInput(name)) {
return 1;
}
return 0;
}

size_t OutputSize(const std::string& name) const override {
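
A note on the InputSize change above: the lookup now falls back from duplicable inputs to single inputs and finally to 0 instead of throwing, which matters for optional inputs such as reshape's ShapeTensor. A minimal standalone sketch of that contract, with a toy input table standing in for the real context (names and types here are illustrative, not Paddle's API):

#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy model: a duplicable input maps to several variable names, a single
// input maps to exactly one, and an optional input that was never fed is
// simply absent.
using InputTable = std::map<std::string, std::vector<std::string>>;

// Mirrors the patched InputSize: report the list size when present,
// 1 for a single input, and 0 (rather than throwing) when absent.
size_t InputSize(const InputTable& inputs, const std::string& name) {
  auto it = inputs.find(name);
  return it == inputs.end() ? 0 : it->second.size();
}

int main() {
  InputTable t{{"ShapeTensorList", {"s0", "s1"}}, {"X", {"x"}}};
  assert(InputSize(t, "ShapeTensorList") == 2);  // duplicable input
  assert(InputSize(t, "X") == 1);                // single input
  assert(InputSize(t, "ShapeTensor") == 0);      // absent optional input
  return 0;
}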
@@ -288,6 +294,16 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
auto& attr_names = std::get<1>(signature.args);
auto& output_names = std::get<2>(signature.args);

auto kernels_map =
phi::KernelFactory::Instance().SelectKernelMap(signature.name);
if (kernels_map.size() == 0) {
PADDLE_THROW(
platform::errors::Unimplemented("Not find `%s` kernels when construct "
"InferMetaContext.",
signature.name));
}
auto attr_defs = kernels_map.cbegin()->second.args_def().attribute_defs();

// TODO(chenweihang): support multiple inputs and outputs later
phi::InferMetaContext infer_meta_context;
for (auto& in_name : input_names) {
@@ -299,9 +315,70 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
}

for (auto& out_name : output_names) {
if (ctx->HasOutput(out_name)) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
} else {
infer_meta_context.EmplaceBackOutput({nullptr});
}
}
auto attr_reader = ctx->Attrs();
- for (auto& attr_name : attr_names) {
- if (ctx->HasAttr(attr_name)) {
for (size_t i = 0; i < attr_names.size(); ++i) {
auto attr_name = attr_names[i];
if (attr_defs[i].type_index == std::type_index(typeid(phi::ScalarArray))) {
// When the attr is a tensor or a vector of tensors, transform it to a ScalarArray
if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
if (ctx->IsRuntime()) {
// At runtime, we can read the tensors' values for the ScalarArray
// and push it into attrs
std::vector<Variable*> vars;
vars.reserve(infershape_inputs.size());
for (size_t i = 0; i < infershape_inputs.size(); i++) {
vars.push_back(BOOST_GET_CONST(Variable*, infershape_inputs[i]));
}
if (infershape_inputs.size() != 1) {
infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePtenScalarArrayFromVarList(vars)));
} else {
infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePtenScalarArrayFromVar(*vars[0])));
}
} else {
// Not at runtime, so set the default value (-1) for the ScalarArray
int64_t num_ele = 1;
std::vector<VarDesc*> vars;
vars.reserve(infershape_inputs.size());
for (size_t i = 0; i < infershape_inputs.size(); i++) {
vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i]));
}
for (auto& var : vars) {
const auto& tensor_dims = var->GetShape();
for (size_t i = 0; i < tensor_dims.size(); ++i) {
num_ele *= tensor_dims[i];
}
}
phi::ScalarArray tensor_attr(std::vector<int32_t>(num_ele, -1));
tensor_attr.SetFromTensor(true);
infer_meta_context.EmplaceBackAttr(std::move(tensor_attr));
}
} else if (ctx->HasAttr(attr_name)) {
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
infer_meta_context.EmplaceBackAttr(std::move(
phi::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to ScalarArray when "
"construct KernelContext.",
attr_name));
}
}

} else if (ctx->HasAttr(attr_name)) {
// Emplace Back Attr according to the type of attr.
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(bool))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
@@ -345,17 +422,6 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
"Unsupported attribute type is received when call "
"InferShapeFunctor."));
}
- } else {
- // do nothing
}
}

- for (auto& out_name : output_names) {
- if (ctx->HasOutput(out_name)) {
- infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
- ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
- } else {
- infer_meta_context.EmplaceBackOutput({nullptr});
- }
- }

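The rewritten BuildInferMetaContext above does three things: it looks up the kernel's attribute definitions via SelectKernelMap (which is why the test file below now registers a phi kernel), it emplaces outputs before attributes, and it converts shape-carrying inputs into a phi::ScalarArray. At compile time the shape tensors' values are unknown, so the code sizes the attribute from the variables' static dims and fills it with -1 placeholders. A standalone sketch of that compile-time branch, using a simplified stand-in for phi::ScalarArray (names here are assumptions, not Paddle's API):

#include <cstdint>
#include <vector>

// Simplified stand-in for phi::ScalarArray: just the values plus the
// from-tensor flag that ReshapeInferMeta later checks.
struct MiniScalarArray {
  std::vector<int64_t> data;
  bool from_tensor = false;
  void SetFromTensor(bool v) { from_tensor = v; }
};

// Compile-time path from the hunk above: tensor values are unknown, so
// size the attribute from the variables' static dims and fill it with -1.
MiniScalarArray CompileTimeShapeAttr(
    const std::vector<std::vector<int64_t>>& static_dims) {
  // Mirrors the patch: num_ele is the product of every static dim across
  // the shape variables (for the common single shape tensor of shape [n],
  // this yields n placeholders).
  int64_t num_ele = 1;
  for (const auto& dims : static_dims) {
    for (int64_t d : dims) num_ele *= d;
  }
  MiniScalarArray attr;
  attr.data.assign(static_cast<size_t>(num_ele), -1);
  attr.SetFromTensor(true);  // lets infer-meta treat the values as unknown
  return attr;
}

int main() {
  // One shape tensor of static shape [4] -> four -1 placeholders.
  MiniScalarArray a = CompileTimeShapeAttr({{4}});
  return (a.data.size() == 4 && a.from_tensor) ? 0 : 1;
}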
17 changes: 17 additions & 0 deletions paddle/fluid/framework/infershape_utils_test.cc
@@ -23,8 +23,11 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/kernel_registry.h"

namespace paddle {
namespace framework {
@@ -93,6 +96,17 @@ phi::KernelSignature InferShapeUtilsTestOpArgumentMapping(
{});
}

template <typename T, typename Context>
void InferShapeUtilsTestKernel(
const Context& dev_ctx, const phi::DenseTensor& x, bool attr1, int attr2,
int64_t attr3, float attr4, const std::string& attr5,
const std::vector<bool>& attr6, const std::vector<int>& attr7,
const std::vector<int64_t>& attr8, const std::vector<float>& attr9,
const std::vector<double>& attr10, const std::vector<std::string>& attr11,
phi::DenseTensor* out) {
VLOG(6) << "Come into InferShapeUtilsTestKernel";
}

} // namespace framework
} // namespace paddle

@@ -104,6 +118,9 @@ REGISTER_OPERATOR(infer_shape_utils_test,
paddle::framework::InferShapeUtilsTestOpMaker,
InferShapeUtilsTestInferShapeFunctor);

PT_REGISTER_KERNEL(infer_shape_utils_test, CPU, ALL_LAYOUT,
paddle::framework::InferShapeUtilsTestKernel, int) {}

TEST(InferShapeUtilsTest, ALL) {
paddle::framework::ProgramDesc prog;
paddle::framework::proto::BlockDesc proto_block;
26 changes: 9 additions & 17 deletions paddle/fluid/operators/reshape_op.cc
@@ -14,15 +14,19 @@ limitations under the License. */

#include <string>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/pten_utils.h"

// can only include headers from the paddle/phi/api dirs
#include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/reshape_grad_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"

namespace paddle {
namespace framework {
class InferShapeContext;
@@ -472,22 +476,6 @@ class Reshape2Op : public ReshapeOp {
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ReshapeOp(type, inputs, outputs, attrs) {}

- void InferShape(framework::InferShapeContext *ctx) const override {
- PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
- platform::errors::InvalidArgument(
- "Output(XShape) of ReshapeOp should not be null."));
- const auto &x_dims = ctx->GetInputDim("X");
- std::vector<int64_t> xshape_dims(x_dims.size() + 1);
- xshape_dims[0] = 0;
- for (int i = 0; i < x_dims.size(); ++i) {
- xshape_dims[i + 1] = x_dims[i];
- }
- ctx->SetOutputDim("XShape", phi::make_ddim(xshape_dims));
- ctx->ShareLoD("X", /*->*/ "XShape");
-
- ReshapeOp::InferShape(ctx);
- }
};

class Reshape2OpMaker : public ReshapeOpMaker {
Expand Down Expand Up @@ -647,10 +635,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);

DELCARE_INFER_SHAPE_FUNCTOR(reshape2, ReshapeInferShapeFunctor,
PT_INFER_META(phi::ReshapeWithXShapeInferMeta));

REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
ops::Reshape2GradMaker<paddle::framework::OpDesc>,
ops::Reshape2GradMaker<paddle::imperative::OpBase>,
- ops::ReshapeOpInplaceInferer);
ReshapeInferShapeFunctor, ops::ReshapeOpInplaceInferer);
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>,
ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
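The removed Reshape2Op::InferShape is replaced by the declarative registration above: DELCARE_INFER_SHAPE_FUNCTOR (the macro's spelling in this snapshot) binds reshape2 to phi::ReshapeWithXShapeInferMeta, and the resulting functor is handed to REGISTER_OPERATOR. A hedged sketch of the same migration for some other operator (op, maker, and InferMeta names below are hypothetical):

// Hypothetical example of the same migration for another operator:
// delete the hand-written InferShape override, then wire the op to a
// phi InferMeta function through the infershape functor.
DELCARE_INFER_SHAPE_FUNCTOR(my_unary_op, MyUnaryOpInferShapeFunctor,
                            PT_INFER_META(phi::MyUnaryInferMeta));

REGISTER_OPERATOR(my_unary_op, ops::MyUnaryOp, ops::MyUnaryOpMaker,
                  MyUnaryOpInferShapeFunctor);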
2 changes: 1 addition & 1 deletion paddle/phi/api/lib/utils/tensor_utils.cc
@@ -131,7 +131,7 @@ phi::ScalarArray MakePtenScalarArrayFromVarList(
}

phi::ScalarArray result{vector_data};
- result.setInitByTensor(true);
result.SetFromTensor(true);

return result;
}
6 changes: 3 additions & 3 deletions paddle/phi/common/scalar.h
@@ -25,7 +25,7 @@ namespace experimental {
template <typename T>
class ScalarBase {
public:
- bool IsInitByTensor() const { return is_init_by_tensor_; }
bool FromTensor() const { return is_from_tensor_; }
// Constructor support implicit
ScalarBase(double val) : dtype_(DataType::FLOAT64) { // NOLINT
data_.f64 = val;
@@ -104,7 +104,7 @@ class ScalarBase {

// The Tensor must have one dim
ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT
- is_init_by_tensor_ = true;
is_from_tensor_ = true;
PD_CHECK(
tensor.numel() == 1,
"The Scalar only supports Tensor with 1 element, but now Tensor has `",
@@ -196,7 +196,7 @@ class ScalarBase {
friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst);

private:
- bool is_init_by_tensor_{false};
bool is_from_tensor_{false};
DataType dtype_;
union data {
bool b;
10 changes: 5 additions & 5 deletions paddle/phi/common/scalar_array.h
@@ -43,13 +43,13 @@ class ScalarArrayBase {
AssignData(date_value, n);
}

- bool IsInitByTensor() const { return is_init_by_tensor_; }
bool FromTensor() const { return is_from_tensor_; }

- void setInitByTensor(bool val) { is_init_by_tensor_ = val; }
void SetFromTensor(bool val) { is_from_tensor_ = val; }

// The Tensor must have one dim
ScalarArrayBase(const T& tensor) { // NOLINT
- is_init_by_tensor_ = true;
is_from_tensor_ = true;
size_t n = tensor.numel();
array_.reserve(n);
switch (tensor.dtype()) {
@@ -71,7 +71,7 @@

// The Tensor in vec must have only one element
ScalarArrayBase(const std::vector<T>& tensor_list) { // NOLINT
- is_init_by_tensor_ = true;
is_from_tensor_ = true;

for (size_t i = 0; i < tensor_list.size(); ++i) {
DataType data_type = tensor_list[i].dtype();
@@ -117,7 +117,7 @@
// TODO(zhangyunfei) Replace std::vector with a more efficient container
// structure.
std::vector<int64_t> array_;
- bool is_init_by_tensor_{false};
bool is_from_tensor_{false};
};

using ScalarArray =
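Renaming IsInitByTensor/setInitByTensor to FromTensor/SetFromTensor makes call sites read as a statement about where the values came from. A small sketch of the two ways a ScalarArray shape attribute gets built, using only the constructors and accessors visible above (the -1 placeholder convention matches the infershape code earlier in this diff):

#include <cstdint>
#include <vector>

#include "paddle/phi/common/scalar_array.h"

// Shape fully known at compile time: built from plain ints, so
// FromTensor() stays false and infer-meta may trust the values.
phi::ScalarArray StaticShape() {
  return phi::ScalarArray(std::vector<int32_t>{2, 3, 4});
}

// Shape fed by a ShapeTensor: the values are -1 placeholders until
// runtime, so the flag tells infer-meta not to trust them.
phi::ScalarArray PlaceholderShape(size_t num_ele) {
  phi::ScalarArray shape(std::vector<int32_t>(num_ele, -1));
  shape.SetFromTensor(true);
  return shape;
}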
4 changes: 4 additions & 0 deletions paddle/phi/core/kernel_utils.h
@@ -241,6 +241,10 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::string&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<bool>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<float>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<double>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<std::string>&);

/* Output Helpers */

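These four specializations let phi kernels declare the remaining std::vector attribute types directly in their signatures; the InferShapeUtilsTestKernel added earlier in this diff exercises all of them. A sketch of a kernel signature that only compiles with the new helpers (kernel and parameter names are illustrative):

#include <string>
#include <vector>

#include "paddle/phi/core/dense_tensor.h"

// Illustrative phi kernel signature: std::vector<bool>, std::vector<float>,
// std::vector<double> and std::vector<std::string> attributes are now
// routed through PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE.
template <typename T, typename Context>
void MyKernel(const Context& dev_ctx,
              const phi::DenseTensor& x,
              const std::vector<bool>& flags,
              const std::vector<float>& scales,
              const std::vector<std::string>& names,
              phi::DenseTensor* out);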
42 changes: 38 additions & 4 deletions paddle/phi/infermeta/unary.cc
@@ -15,8 +15,8 @@ limitations under the License. */
#include "paddle/phi/infermeta/unary.h"

#include <set>

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/infermeta_utils.h"

namespace phi {
@@ -217,7 +217,7 @@ void InferMetaFromVecValue(const MetaTensor& x,
MetaTensor* out) {
PADDLE_ENFORCE_EQ(!shape.empty(),
true,
- paddle::platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The parameter 'shape' in ReshapeOp must be set. "
"But received 'shape' is empty."));
auto x_dims = x.dims();
@@ -234,8 +234,42 @@

void ReshapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
- MetaTensor* out) {
- InferMetaFromVecValue(x, shape.GetData(), out);
MetaTensor* out,
MetaConfig config) {
auto& shape_data = shape.GetData();
PADDLE_ENFORCE_NOT_NULL(out,
phi::errors::InvalidArgument(
"Output(Out) of ReshapeOp should not be null."));
if (!config.is_runtime && shape.FromTensor()) {
out->set_dims(phi::make_ddim(shape_data));
out->share_lod(x);
return;
}
PADDLE_ENFORCE_GT(shape_data.size(),
0,
phi::errors::InvalidArgument(
"The shape's size in ReshapeOp can't be zero."));
InferMetaFromVecValue(x, shape_data, out);
}

void ReshapeWithXShapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
MetaTensor* xshape,
MetaTensor* out,
MetaConfig config) {
PADDLE_ENFORCE_NOT_NULL(
xshape,
phi::errors::InvalidArgument(
"Output(XShape) of ReshapeOp should not be null."));
const auto& x_dims = x.dims();
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
ReshapeInferMeta(x, shape, out, config);
}
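
ReshapeWithXShapeInferMeta reproduces the logic deleted from Reshape2Op::InferShape: XShape records the input's dimensions behind a leading 0 so that the gradient op can recover them later without keeping X alive. For x with dims [2, 3, 4], xshape gets dims [0, 2, 3, 4]. A minimal standalone sketch of the dims computation:

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the loop above: prepend a 0, then copy the input dims.
std::vector<int64_t> XShapeDims(const std::vector<int64_t>& x_dims) {
  std::vector<int64_t> xshape_dims(x_dims.size() + 1);
  xshape_dims[0] = 0;
  for (size_t i = 0; i < x_dims.size(); ++i) xshape_dims[i + 1] = x_dims[i];
  return xshape_dims;
}

int main() {
  assert((XShapeDims({2, 3, 4}) == std::vector<int64_t>{0, 2, 3, 4}));
  return 0;
}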

/* Why not use ReduceInferMeta directly?
9 changes: 8 additions & 1 deletion paddle/phi/infermeta/unary.h
@@ -54,7 +54,14 @@ void InferMetaFromVecValue(const MetaTensor& x,

void ReshapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
- MetaTensor* out);
MetaTensor* out,
MetaConfig config = MetaConfig());

void ReshapeWithXShapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
MetaTensor* xshape,
MetaTensor* out,
MetaConfig config = MetaConfig());

void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis,
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/split_kernel.cc
@@ -29,7 +29,7 @@ void SplitKernel(const Context& dev_ctx,
const Scalar& axis_scalar,
std::vector<DenseTensor*> outs) {
// need to infershape output
- if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) {
if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) {
std::vector<MetaTensor> out_metas;
for (size_t i = 0; i < outs.size(); ++i) {
out_metas.push_back(outs[i]);
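
The rename reaches the split kernel too: FromTensor() reports that num_or_sections or axis arrived via tensors, meaning their compile-time values were -1 placeholders and output metadata must be recomputed now that the real values exist. A simplified standalone sketch of the guard's intent (types are stand-ins, not phi's):

#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for a scalar-array attribute: values plus provenance flag.
struct MiniAttr {
  std::vector<int64_t> values;
  bool from_tensor = false;
  bool FromTensor() const { return from_tensor; }
};

// Mirrors the kernel's guard: only re-infer output shapes when at least
// one attribute was fed by a tensor (i.e., held placeholders until now).
bool NeedRuntimeInferShape(const MiniAttr& num_or_sections,
                           const MiniAttr& axis) {
  return num_or_sections.FromTensor() || axis.FromTensor();
}

int main() {
  MiniAttr sections{{-1, -1}, true};  // placeholders from a shape tensor
  MiniAttr axis{{0}, false};          // ordinary compile-time attribute
  assert(NeedRuntimeInferShape(sections, axis));
  return 0;
}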