From d20e27a9723eba8f6162c9f33f10287252bd2d02 Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Wed, 10 Jan 2024 13:52:47 +0000 Subject: [PATCH 1/2] replace generate_shape_op related corner case handling code with a more general function DoGenerateShapeOpGroupFustion --- .../group_with_group_merge_pass.cc | 42 +++++++++++++++++++ .../group_merge/op_with_group_merge_pass.cc | 7 +--- paddle/fluid/pybind/pir.cc | 4 +- 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc index 1c28039718a74..85ae32b8a337e 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc @@ -29,6 +29,8 @@ #include "paddle/cinn/common/is_reachable_predicator.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" + PD_DECLARE_bool(enhance_vertical_fusion_with_recompute); namespace cinn { @@ -1095,6 +1097,46 @@ class GeneralFusionMergePassHelper { while (DoGeneralRecomputeAndVerticalFusion()) { } + DoGenerateShapeOpGroupFustion(); + } + + void DoGenerateShapeOpGroupFustion() { + VLOG(3) << "DoFuseGenerateShapeOp...!"; + bool updated = false; + for (size_t idx = 0; idx < fusion_groups_.size(); ++idx) { + auto producer = fusion_groups_[idx]; + VLOG(3) << "Fusion Producer idx " << idx << " Group -> " + << producer->group_id; + // Skip to next iterator if producer is sub group. + if (producer->belong_groups.size()) continue; + + // generate-shape-op group has only one op. + if (producer->ops.size() != 1) continue; + + // only generate-shape-op groups will be processed. + if (!producer->ops.at(0)->isa()) continue; + + // generate-shape-op groups have no inputs. + if (!producer->input_ops.empty()) continue; + + // do generate-shape-op fusion. + updated |= FuseGenerateShapeOpGroupToConsumer(producer); + } + if (updated) { + UpdateFusionGroup(); + } + } + + bool FuseGenerateShapeOpGroupToConsumer(const GroupPtr& producer) { + VLOG(3) << "FuseGenerateShapeOpGroupToConsumer handling producer : " + << producer->group_id; + // copy is need. + auto consumer_groups = producer->consumer_groups(); + if (consumer_groups.size() > 0) { + RecomputeFuse(producer, consumer_groups); + return true; + } + return false; } bool DoGeneralHorizontalFusion() { diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc index f5050e51e197b..54005eb22f25b 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc @@ -387,8 +387,7 @@ class OpFusionPassHelper { // kNonFusible op can't fuse any other op. auto producer_kind = hlir::framework::pir::CompatibleInfo::OpKind(*producer); - if (producer_kind == OpPatternKind::kNonFusible && - producer->name() != "cinn_op.generate_shape") { + if (producer_kind == OpPatternKind::kNonFusible) { continue; } // VLOG(3) << "Producer Op: " << producer->id() @@ -430,8 +429,7 @@ class OpFusionPassHelper { consumer_fusion->input_ops.erase(producer); consumer_fusion->op_pattern_kind = static_cast(consumer_fusion->op_pattern_kind) > - static_cast(producer_kind) || - producer->name() == "cinn_op.generate_shape" + static_cast(producer_kind) ? consumer_fusion->op_pattern_kind : producer_kind; @@ -595,7 +593,6 @@ class OpFusionPassHelper { } bool CanFuse(::pir::Operation* producer, const ::pir::Operation* consumer) { - if (producer->name() == "cinn_op.generate_shape") return true; auto& relation = fusion_relation_map_[hlir::framework::pir::CompatibleInfo::OpKind( *producer)]; diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 02f61625a3dc0..cfbf5e9ce02a4 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -1605,7 +1605,6 @@ void AddCinnPass(std::shared_ptr &pass_manager, // NOLINT std::make_unique< cinn::dialect::ir::FuseShapeOpsIntoGenerateShapeOpPass>()); pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); - // pass_manager->AddPass(pir::CreateShapeOptimizationPass()); } cinn::dialect::ir::PdOp2CinnOpConverter(&program); @@ -1615,8 +1614,7 @@ void AddCinnPass(std::shared_ptr &pass_manager, // NOLINT pass_manager->AddPass(pir::CreateBuildCinnPass()); if (has_dynamic_shape) { - // pass_manager->AddPass(cinn::dialect::ir::CreateRewriteGenerateShapeOpsToRunFirstPass()); - // pass_manager->AddPass(pir::CreateShapeOptimizationPass()); + pass_manager->AddPass(cinn::dialect::ir::CreateRewriteGenerateShapeOpsToRunFirstPass()); } pass_manager->AddPass( From df477bc3a4a12f59ac6db1d08aa12baf0d2a54dd Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Wed, 10 Jan 2024 14:37:27 +0000 Subject: [PATCH 2/2] rename XXXRunFirstXXX to XXXPrologueXXX --- .../transforms/group_merge/CMakeLists.txt | 2 +- .../group_merge/group_with_group_merge_pass.cc | 18 +++++++++--------- ...ove_generate_shape_ops_to_prologue_pass.cc} | 14 +++++++------- ...move_generate_shape_ops_to_prologue_pass.h} | 2 +- paddle/fluid/pybind/pir.cc | 4 ++-- 5 files changed, 20 insertions(+), 20 deletions(-) rename paddle/cinn/hlir/dialect/operator/transforms/group_merge/{rewrite_generate_shape_ops_to_run_first_pass.cc => move_generate_shape_ops_to_prologue_pass.cc} (88%) rename paddle/cinn/hlir/dialect/operator/transforms/group_merge/{rewrite_generate_shape_ops_to_run_first_pass.h => move_generate_shape_ops_to_prologue_pass.h} (91%) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/CMakeLists.txt index 0166a4a652507..d30cc58436f21 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/CMakeLists.txt @@ -5,7 +5,7 @@ if(NOT CINN_ONLY) group_with_group_merge_pass.cc op_with_group_merge_pass.cc cinn_group_lowering_pass.cc - rewrite_generate_shape_ops_to_run_first_pass.cc + move_generate_shape_ops_to_prologue_pass.cc generate_shape_util.cc tensor_node.cc DEPS diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc index 85ae32b8a337e..73f4f7aa45096 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc @@ -1097,11 +1097,11 @@ class GeneralFusionMergePassHelper { while (DoGeneralRecomputeAndVerticalFusion()) { } - DoGenerateShapeOpGroupFustion(); + DoPrologueGenerateShapeOpGroupFustion(); } - void DoGenerateShapeOpGroupFustion() { - VLOG(3) << "DoFuseGenerateShapeOp...!"; + void DoPrologueGenerateShapeOpGroupFustion() { + VLOG(3) << "DoPrologueGenerateShapeOpGroupFustion...!"; bool updated = false; for (size_t idx = 0; idx < fusion_groups_.size(); ++idx) { auto producer = fusion_groups_[idx]; @@ -1110,25 +1110,25 @@ class GeneralFusionMergePassHelper { // Skip to next iterator if producer is sub group. if (producer->belong_groups.size()) continue; - // generate-shape-op group has only one op. + // a prologue generate-shape-op group has only one op. if (producer->ops.size() != 1) continue; - // only generate-shape-op groups will be processed. + // only prologue generate-shape-op groups will be processed. if (!producer->ops.at(0)->isa()) continue; - // generate-shape-op groups have no inputs. + // prologue generate-shape-op groups have no inputs. if (!producer->input_ops.empty()) continue; // do generate-shape-op fusion. - updated |= FuseGenerateShapeOpGroupToConsumer(producer); + updated |= FusePrologueGenerateShapeOpGroupToConsumer(producer); } if (updated) { UpdateFusionGroup(); } } - bool FuseGenerateShapeOpGroupToConsumer(const GroupPtr& producer) { - VLOG(3) << "FuseGenerateShapeOpGroupToConsumer handling producer : " + bool FusePrologueGenerateShapeOpGroupToConsumer(const GroupPtr& producer) { + VLOG(3) << "FusePrologueGenerateShapeOpGroupToConsumer handling producer : " << producer->group_id; // copy is need. auto consumer_groups = producer->consumer_groups(); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/rewrite_generate_shape_ops_to_run_first_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc similarity index 88% rename from paddle/cinn/hlir/dialect/operator/transforms/group_merge/rewrite_generate_shape_ops_to_run_first_pass.cc rename to paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc index 49ac7b5e61de9..e9d1a2ffe56ef 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/rewrite_generate_shape_ops_to_run_first_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc @@ -34,7 +34,7 @@ #include "paddle/pir/pass/pass_registry.h" #include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.h" -#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/rewrite_generate_shape_ops_to_run_first_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.h" #include "paddle/pir/dialect/shape/utils/shape_utils.h" namespace cinn { @@ -66,10 +66,10 @@ class GroupOpGenerateShapeOpsPattern : public pir::OpRewritePatternisa() && op->num_regions() > 0)) return false; auto* program = op->GetParentProgram(); - VLOG(4) << "Before RewriteGenerateShapeOpsToRunFirstPass: " << *program; + VLOG(4) << "Before MoveGenerateShapeOpsToProloguePass: " << *program; return true; } @@ -90,8 +90,8 @@ class RewriteGenerateShapeOpsToRunFirstPass : public pir::PatternRewritePass { namespace ir { -std::unique_ptr<::pir::Pass> CreateRewriteGenerateShapeOpsToRunFirstPass() { - return std::make_unique(); +std::unique_ptr<::pir::Pass> CreateMoveGenerateShapeOpsToProloguePass() { + return std::make_unique(); } } // namespace ir diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/rewrite_generate_shape_ops_to_run_first_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.h similarity index 91% rename from paddle/cinn/hlir/dialect/operator/transforms/group_merge/rewrite_generate_shape_ops_to_run_first_pass.h rename to paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.h index 6d0c6a8e017eb..65afa519cf50c 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/rewrite_generate_shape_ops_to_run_first_pass.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.h @@ -21,7 +21,7 @@ namespace cinn { namespace dialect { namespace ir { -std::unique_ptr<::pir::Pass> CreateRewriteGenerateShapeOpsToRunFirstPass(); +std::unique_ptr<::pir::Pass> CreateMoveGenerateShapeOpsToProloguePass(); } // namespace ir } // namespace dialect diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index cfbf5e9ce02a4..762ac6bc569e9 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -86,7 +86,7 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.h" -#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/rewrite_generate_shape_ops_to_run_first_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h" #include "paddle/cinn/hlir/framework/pir_compiler.h" @@ -1614,7 +1614,7 @@ void AddCinnPass(std::shared_ptr &pass_manager, // NOLINT pass_manager->AddPass(pir::CreateBuildCinnPass()); if (has_dynamic_shape) { - pass_manager->AddPass(cinn::dialect::ir::CreateRewriteGenerateShapeOpsToRunFirstPass()); + pass_manager->AddPass(cinn::dialect::ir::CreateMoveGenerateShapeOpsToProloguePass()); } pass_manager->AddPass(