From b63595ba800188e7d06fcd74e8c32f271bc35881 Mon Sep 17 00:00:00 2001 From: Hongqing-work Date: Fri, 10 May 2024 06:58:04 +0000 Subject: [PATCH] [CINN]fix InferSymbolicShape for builtin.slice and pd_op.stack --- .../infer_symbolic_shape/multiary_infer_sym.cc | 5 ++++- paddle/fluid/pir/dialect/operator/ir/op_dialect.cc | 14 ++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 4a8b0c1000cc0..0e3ed5e1d9c4a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -589,7 +589,10 @@ bool StackOpInferSymbolicShape(pir::Operation *op, shape_dim_exprs.insert(shape_dim_exprs.begin() + axis, static_cast(shape_data_list.size())); } - + if (data_dim_exprs.empty()) { + return symbol::ShapeOrDataDimExprs( + symbol::TensorShapeOrDataDimExprs(shape_dim_exprs)); + } return symbol::ShapeOrDataDimExprs( symbol::TensorShapeOrDataDimExprs(shape_dim_exprs, data_dim_exprs)); }(); diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 5eea152adbdb1..3b1521e4787bd 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -170,12 +170,14 @@ struct SliceOpInferSymbolicShapeInterfaceModel pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { const auto index = op->attributes().at("index").dyn_cast().data(); - const auto output_value = - (op->operand(0).type().dyn_cast())[index] - .dyn_cast(); - - infer_context->SetShapeOrDataForValue( - op->result(0), infer_context->GetShapeOrDataForValue(output_value)); + const auto& input_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + CHECK(input_shape.isa()); + const symbol::TensorListShapeOrDataDimExprs& data_shape_list = + input_shape.dyn_cast(); + const symbol::TensorShapeOrDataDimExprs& output_shape = + data_shape_list[index]; + infer_context->SetShapeOrDataForValue(op->result(0), output_shape); return true; }