From c73b6cea0d50226447c506c07929e788eaf1ba22 Mon Sep 17 00:00:00 2001
From: SigureMo
Date: Tue, 27 Feb 2024 08:22:53 +0000
Subject: [PATCH] [PIR] Align reshape infermeta to legacy IR infershape

---
 paddle/phi/api/yaml/ops.yaml        |  2 +-
 paddle/phi/infermeta/unary.cc       | 21 +++++++++++++++++++--
 paddle/phi/infermeta/unary.h        |  2 +-
 test/legacy_test/test_reshape_op.py | 14 ++++++++++----
 4 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 6605c214a97d4..2ee266e1af398 100755
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -1696,7 +1696,7 @@
   args : (Tensor x)
   output : Tensor(out)
   infer_meta :
-    func : LogicalNotInfermeta
+    func : LogicalNotInferMeta
   kernel :
     func : logical_not
     data_type : x
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 42eaa2670a0b5..5648ff0d469a3 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -2188,7 +2188,7 @@
   indices->set_dtype(x.dtype());
 }
 
-void LogicalNotInfermeta(const MetaTensor& x, MetaTensor* out) {
+void LogicalNotInferMeta(const MetaTensor& x, MetaTensor* out) {
   UnchangedInferMeta(x, out);
   if (!(out->is_same_tensor(x))) {
     out->set_dtype(DataType::BOOL);
@@ -3588,11 +3588,28 @@ void ReshapeInferMeta(const MetaTensor& x,
                       const IntArray& shape,
                       MetaTensor* out,
                       MetaConfig config) {
-  auto& shape_data = shape.GetData();
+  auto shape_data = shape.GetData();
   PADDLE_ENFORCE_NOT_NULL(out,
                           phi::errors::InvalidArgument(
                               "Output(Out) of ReshapeOp should not be null."));
   if (!config.is_runtime && shape.FromTensor()) {
+    const int64_t copy_dim_flag = 0;
+    const auto& in_dims = x.dims();
+    for (size_t i = 0; i < shape_data.size(); ++i) {
+      if (shape_data[i] == copy_dim_flag) {
+        PADDLE_ENFORCE_LT(
+            static_cast<int>(i),
+            in_dims.size(),
+            phi::errors::InvalidArgument(
+                "The index of 0 in `shape` must be less than "
+                "the input tensor X's dimensions. But received shape[%d] "
+                "= 0, X's dimensions = %d, X's shape = [%s].",
+                i,
+                in_dims.size(),
+                in_dims));
+        shape_data[i] = static_cast<int64_t>(in_dims[static_cast<int>(i)]);
+      }
+    }
     out->set_dims(common::make_ddim(shape_data));
     out->share_lod(x);
     out->set_dtype(x.dtype());
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index a9f5f2eb1a13c..d62789bd5183c 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -328,7 +328,7 @@ void KthvalueInferMeta(const MetaTensor& x,
                        MetaTensor* indices,
                        MetaConfig = MetaConfig());
 
-void LogicalNotInfermeta(const MetaTensor& x, MetaTensor* out);
+void LogicalNotInferMeta(const MetaTensor& x, MetaTensor* out);
 
 void LogsumexpInferMeta(const MetaTensor& input,
                         const std::vector<int64_t>& axis,
diff --git a/test/legacy_test/test_reshape_op.py b/test/legacy_test/test_reshape_op.py
index c786562062554..964f487a4ed9d 100755
--- a/test/legacy_test/test_reshape_op.py
+++ b/test/legacy_test/test_reshape_op.py
@@ -737,15 +737,21 @@ def test_static(self):
 class TestReshapePirValueListShape(unittest.TestCase):
     def test_value_list_shape(self):
         with paddle.pir_utils.IrGuard():
-            x = paddle.static.data(
-                'x',
-                [3],
-            )
+            x = paddle.static.data('x', [3])
             shape = [1, paddle.full([], 3)]
             out = paddle.reshape(x, shape)
             self.assertEqual(out.shape, [1, -1])
 
 
+class TestReshapePirTensorWithZeroShape(unittest.TestCase):
+    def test_tensor_with_zero_shape(self):
+        with paddle.pir_utils.IrGuard():
+            x = paddle.static.data('x', [10, -1])
+            shape = [0, paddle.shape(x)[1]]
+            out = paddle.reshape(x, shape)
+            self.assertEqual(out.shape, [10, -1])
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()