From 93a01bbfd959f477964f67f5e7accaa878705c1e Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Sun, 15 Sep 2024 22:51:28 +0800 Subject: [PATCH] Add fused gemm infer symbolic (#68198) * update * update * fix compile bug * fix bug * add reserve space and constrain * revert fused gemm infer meta --- .../multiary_infer_sym.cc | 55 +++++++++++++++++++ .../infer_symbolic_shape/multiary_infer_sym.h | 1 + .../pir/dialect/operator/ir/manual_op.cc | 6 ++ .../fluid/pir/dialect/operator/ir/manual_op.h | 9 +-- 4 files changed, 67 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index b5bee94f6a0f6..4f6b03577a206 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -1563,6 +1563,61 @@ bool FusedBnAddActivation_OpInferSymbolicShape( return BatchNormOpInferSymbolicShape(op, infer_context); } +bool FusedGemmEpilogueOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &x_dims = x_shape_or_data.shape(); + + const auto &y_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + const auto &y_dims = y_shape_or_data.shape(); + + const auto &bias_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + const auto &bias_dims = bias_shape_or_data.shape(); + + size_t x_rank = x_dims.size(); + size_t y_rank = y_dims.size(); + + std::vector out_shape; + out_shape.reserve(x_rank); + + for (size_t i = 0; i + 2 < x_rank; ++i) { + out_shape.emplace_back(x_dims[i]); + } + + bool transpose_x_attr = GetBoolAttr(op, "trans_x"); + bool transpose_y_attr = GetBoolAttr(op, "trans_y"); + + symbol::DimExpr out_M = + transpose_x_attr ? x_dims[x_rank - 1] : x_dims[x_rank - 2]; + symbol::DimExpr out_N = + transpose_y_attr ? y_dims[y_rank - 2] : y_dims[y_rank - 1]; + + out_shape.emplace_back(out_M); + out_shape.emplace_back(out_N); + + symbol::DimExpr x_K = + transpose_x_attr ? x_dims[x_rank - 2] : x_dims[x_rank - 1]; + symbol::DimExpr y_K = + transpose_y_attr ? y_dims[y_rank - 1] : y_dims[y_rank - 2]; + + infer_context->AddEqualCstr(x_K, y_K); + // bias_dims[0] equal to out_N + infer_context->AddEqualCstr(out_N, bias_dims[0]); + + infer_context->SetShapeOrDataForValue(op->result(0), + ShapeOrData{TensorExprs(out_shape)}); + + // process reserve space + if (paddle::dialect::details::IsFakeValue(op->result(1))) { + infer_context->SetShapeOrDataForValue(op->result(0), + ShapeOrData{TensorExprs(out_shape)}); + } + return true; +} + bool FusedMultiTransformerOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index 1e7babac4a164..b4fff39faff7a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -61,6 +61,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBatchNormAct) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBatchNormAct_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedGemmEpilogue) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedMultiTransformer) OP_DECLARE_INFER_SYMBOLIC_SHAPE(GenerateProposals) OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphKhopSampler) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index da354e03bd209..984e27c387459 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -888,6 +888,12 @@ std::vector FusedGemmEpilogueOp::InferMeta( return argument_outputs; } +bool FusedGemmEpilogueOp::InferSymbolicShape( + pir::InferSymbolicShapeContext *infer_context) { + return FusedGemmEpilogueOpInferSymbolicShape(this->operation(), + infer_context); +} + const char *FusedGemmEpilogueGradOp::attributes_name[3] = { // NOLINT "trans_x", "trans_y", diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index ba72cfeab52cb..6b31b7a7994c0 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -84,10 +84,10 @@ class AddNArrayOp : public pir::Op { +class FusedGemmEpilogueOp : public pir::Op { public: using Op::Op; static const char *name() { return "pd_op.fused_gemm_epilogue"; } @@ -112,6 +112,7 @@ class FusedGemmEpilogueOp static std::vector InferMeta( const std::vector &input_values, pir::AttributeMap *p_attributes); + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); }; class TEST_API FusedGemmEpilogueGradOp