Skip to content

Commit

Permalink
fix bug of infer symbol shape about tensor_list (PaddlePaddle#66662)
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg authored and inaomIIsfarell committed Jul 31, 2024
1 parent 1c28620 commit ae7cbca
Showing 1 changed file with 6 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -925,17 +925,9 @@ bool SplitOpInferSymbolicShape(pir::Operation *op,
axis = axis >= 0 ? axis : std::max(int64_t(0), int64_t(axis + rank));

// sections
const std::vector<symbol::DimExpr> &sections_sym = [&] {
const auto &sections_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
std::vector<symbol::DimExpr> sections_sym;
if (sections_shape_or_data.data().has_value()) {
sections_sym = sections_shape_or_data.data().value();
} else {
sections_sym = sections_shape_or_data.shape();
}
return sections_sym;
}();
const std::vector<symbol::DimExpr> &sections_sym =
details::GetExprVecFromData(
infer_context->GetShapeOrDataForValue(op->operand_source(1)));

// output
const symbol::TensorListShapeOrDataDimExprs &output_shape_data_list = [&] {
Expand Down Expand Up @@ -1123,19 +1115,9 @@ bool TileOpInferSymbolicShape(pir::Operation *op,
symbol::ShapeOrDataDimExprs repeat_times_shape_or_data =
infer_context->GetShapeOrDataForValue(operand_repeat_times);

std::vector<symbol::DimExpr> x_dimexpr;
if (x_shape_or_data.data().has_value()) {
x_dimexpr = x_shape_or_data.data().value();
} else {
x_dimexpr = x_shape_or_data.shape();
}

std::vector<symbol::DimExpr> repeat_times_dimexpr;
if (repeat_times_shape_or_data.data().has_value()) {
repeat_times_dimexpr = repeat_times_shape_or_data.data().value();
} else {
repeat_times_dimexpr = repeat_times_shape_or_data.shape();
}
std::vector<symbol::DimExpr> x_dimexpr = x_shape_or_data.shape();
std::vector<symbol::DimExpr> repeat_times_dimexpr =
details::GetExprVecFromData(repeat_times_shape_or_data);
if (repeat_times_dimexpr.empty()) {
repeat_times_dimexpr = std::vector<symbol::DimExpr>(x_dimexpr.size(), 1);
}
Expand Down

0 comments on commit ae7cbca

Please sign in to comment.