From ae7cbca176d25fb3401dedffa342b7f432eec287 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 29 Jul 2024 10:12:19 +0800 Subject: [PATCH] fix bug of infer symbol shape about tensor_list (#66662) --- .../infer_symbolic_shape/unary_infer_sym.cc | 30 ++++--------------- 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 79512fb69ea95..010d11ace051e 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -925,17 +925,9 @@ bool SplitOpInferSymbolicShape(pir::Operation *op, axis = axis >= 0 ? axis : std::max(int64_t(0), int64_t(axis + rank)); // sections - const std::vector §ions_sym = [&] { - const auto §ions_shape_or_data = - infer_context->GetShapeOrDataForValue(op->operand_source(1)); - std::vector sections_sym; - if (sections_shape_or_data.data().has_value()) { - sections_sym = sections_shape_or_data.data().value(); - } else { - sections_sym = sections_shape_or_data.shape(); - } - return sections_sym; - }(); + const std::vector §ions_sym = + details::GetExprVecFromData( + infer_context->GetShapeOrDataForValue(op->operand_source(1))); // output const symbol::TensorListShapeOrDataDimExprs &output_shape_data_list = [&] { @@ -1123,19 +1115,9 @@ bool TileOpInferSymbolicShape(pir::Operation *op, symbol::ShapeOrDataDimExprs repeat_times_shape_or_data = infer_context->GetShapeOrDataForValue(operand_repeat_times); - std::vector x_dimexpr; - if (x_shape_or_data.data().has_value()) { - x_dimexpr = x_shape_or_data.data().value(); - } else { - x_dimexpr = x_shape_or_data.shape(); - } - - std::vector repeat_times_dimexpr; - if (repeat_times_shape_or_data.data().has_value()) { - repeat_times_dimexpr = repeat_times_shape_or_data.data().value(); - } else { - repeat_times_dimexpr = repeat_times_shape_or_data.shape(); - } + std::vector x_dimexpr = x_shape_or_data.shape(); + std::vector repeat_times_dimexpr = + details::GetExprVecFromData(repeat_times_shape_or_data); if (repeat_times_dimexpr.empty()) { repeat_times_dimexpr = std::vector(x_dimexpr.size(), 1); }