From 1bb53ddfabab9d8228529bad2040d81250fb1058 Mon Sep 17 00:00:00 2001 From: Zero Rains Date: Mon, 8 Apr 2024 10:40:39 +0800 Subject: [PATCH] [Prim][Pir] Fix the bug in test_sub_graph73 when open prim setting (#63276) * fix the bug in test_sub_graph_73 * fix typo * fix the bug in test_sub_graph_73 when set with_prim=True --- paddle/phi/infermeta/ternary.cc | 39 ++++++++++++------- .../pir/cinn/sub_graphs/test_sub_graph_73.py | 2 +- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index c58f53c664e54a..beba7457039cc0 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1448,12 +1448,19 @@ void ScatterNdAddInferMeta(const MetaTensor& x, // update.shape = index.shape[:-1] + output.shape[index.shape[-1]:] std::vector r_updates_dims; + bool without_dynamic_shape = true; for (int i = 0; i < index_dims_size - 1; ++i) { + if (index_dims[i] == -1) { + without_dynamic_shape = false; + } r_updates_dims.emplace_back(index_dims[i]); } for (int i = static_cast(index_dims[index_dims_size - 1]); i < ref_dims_size; ++i) { + if (ref_dims[i] == -1) { + without_dynamic_shape = false; + } r_updates_dims.emplace_back(ref_dims[i]); } // check for non-0d updates @@ -1461,25 +1468,27 @@ void ScatterNdAddInferMeta(const MetaTensor& x, r_updates_dims.size(), updates_dims_size, phi::errors::InvalidArgument( - "Updates has wrong shape. The shape of Updates and Input(Updates) " + "Updates has wrong shape. The shape of Updates and " + "Input(Updates) " "should be same, but received the shape of Updates is %d, " "the shape of Input(Updates) is %d.", r_updates_dims.size(), updates_dims_size)); - - for (int64_t i = 0; i < updates_dims_size; ++i) { - PADDLE_ENFORCE_EQ( - r_updates_dims[i], - updates_dims[i], - phi::errors::InvalidArgument( - "Updates has wrong shape. The dimensions of Updates and " - "Input(Updates) should match, but received Updates's" - "%d-th dimension is %d, Input(Updates)'s %d-th " - "dimension is %d.", - i, - r_updates_dims[i], - i, - updates_dims[i])); + if (without_dynamic_shape) { + for (int64_t i = 0; i < updates_dims_size; ++i) { + PADDLE_ENFORCE_EQ( + r_updates_dims[i], + updates_dims[i], + phi::errors::InvalidArgument( + "Updates has wrong shape. The dimensions of Updates and " + "Input(Updates) should match, but received Updates's" + "%d-th dimension is %d, Input(Updates)'s %d-th " + "dimension is %d.", + i, + r_updates_dims[i], + i, + updates_dims[i])); + } } } out->set_dims(ref_dims); diff --git a/test/ir/pir/cinn/sub_graphs/test_sub_graph_73.py b/test/ir/pir/cinn/sub_graphs/test_sub_graph_73.py index 1fa302916dacdc..c344c87a5591a6 100644 --- a/test/ir/pir/cinn/sub_graphs/test_sub_graph_73.py +++ b/test/ir/pir/cinn/sub_graphs/test_sub_graph_73.py @@ -100,7 +100,7 @@ def train(self, net, to_static, with_prim=False, with_cinn=False): def test_ast_prim_cinn(self): st_out = self.train(self.net, to_static=True) cinn_out = self.train( - self.net, to_static=True, with_prim=False, with_cinn=False + self.net, to_static=True, with_prim=True, with_cinn=False ) for st, cinn in zip( paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)