Skip to content

Commit

Permalink
[Prim][Pir] Fix the bug in test_sub_graph73 when open prim setting (#…
Browse files Browse the repository at this point in the history
…63276)

* fix the bug in test_sub_graph_73

* fix typo

* fix the bug in test_sub_graph_73 when set with_prim=True
  • Loading branch information
zeroRains authored Apr 8, 2024
1 parent 045d416 commit a76e7ed
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
39 changes: 24 additions & 15 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1448,38 +1448,47 @@ void ScatterNdAddInferMeta(const MetaTensor& x,

// update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
std::vector<int64_t> r_updates_dims;
bool without_dynamic_shape = true;
for (int i = 0; i < index_dims_size - 1; ++i) {
if (index_dims[i] == -1) {
without_dynamic_shape = false;
}
r_updates_dims.emplace_back(index_dims[i]);
}
for (int i = static_cast<int>(index_dims[index_dims_size - 1]);
i < ref_dims_size;
++i) {
if (ref_dims[i] == -1) {
without_dynamic_shape = false;
}
r_updates_dims.emplace_back(ref_dims[i]);
}
// check for non-0d updates
PADDLE_ENFORCE_EQ(
r_updates_dims.size(),
updates_dims_size,
phi::errors::InvalidArgument(
"Updates has wrong shape. The shape of Updates and Input(Updates) "
"Updates has wrong shape. The shape of Updates and "
"Input(Updates) "
"should be same, but received the shape of Updates is %d, "
"the shape of Input(Updates) is %d.",
r_updates_dims.size(),
updates_dims_size));

for (int64_t i = 0; i < updates_dims_size; ++i) {
PADDLE_ENFORCE_EQ(
r_updates_dims[i],
updates_dims[i],
phi::errors::InvalidArgument(
"Updates has wrong shape. The dimensions of Updates and "
"Input(Updates) should match, but received Updates's"
"%d-th dimension is %d, Input(Updates)'s %d-th "
"dimension is %d.",
i,
r_updates_dims[i],
i,
updates_dims[i]));
if (without_dynamic_shape) {
for (int64_t i = 0; i < updates_dims_size; ++i) {
PADDLE_ENFORCE_EQ(
r_updates_dims[i],
updates_dims[i],
phi::errors::InvalidArgument(
"Updates has wrong shape. The dimensions of Updates and "
"Input(Updates) should match, but received Updates's"
"%d-th dimension is %d, Input(Updates)'s %d-th "
"dimension is %d.",
i,
r_updates_dims[i],
i,
updates_dims[i]));
}
}
}
out->set_dims(ref_dims);
Expand Down
2 changes: 1 addition & 1 deletion test/ir/pir/cinn/sub_graphs/test_sub_graph_73.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def train(self, net, to_static, with_prim=False, with_cinn=False):
def test_ast_prim_cinn(self):
st_out = self.train(self.net, to_static=True)
cinn_out = self.train(
self.net, to_static=True, with_prim=False, with_cinn=False
self.net, to_static=True, with_prim=True, with_cinn=False
)
for st, cinn in zip(
paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)
Expand Down

0 comments on commit a76e7ed

Please sign in to comment.