Skip to content

Commit

Permalink
Merge pull request #10 from jiahy0825/broadcast_dynamic
Browse files Browse the repository at this point in the history
replace generate_shape_op related corner case handling code with a mo…
  • Loading branch information
zyfncg authored Jan 10, 2024
2 parents 7566a35 + df477bc commit 7327ebe
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ if(NOT CINN_ONLY)
group_with_group_merge_pass.cc
op_with_group_merge_pass.cc
cinn_group_lowering_pass.cc
rewrite_generate_shape_ops_to_run_first_pass.cc
move_generate_shape_ops_to_prologue_pass.cc
generate_shape_util.cc
tensor_node.cc
DEPS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

#include "paddle/cinn/common/is_reachable_predicator.h"

#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"

PD_DECLARE_bool(enhance_vertical_fusion_with_recompute);

namespace cinn {
Expand Down Expand Up @@ -1095,6 +1097,46 @@ class GeneralFusionMergePassHelper {

while (DoGeneralRecomputeAndVerticalFusion()) {
}
DoPrologueGenerateShapeOpGroupFustion();
}

void DoPrologueGenerateShapeOpGroupFustion() {
VLOG(3) << "DoPrologueGenerateShapeOpGroupFustion...!";
bool updated = false;
for (size_t idx = 0; idx < fusion_groups_.size(); ++idx) {
auto producer = fusion_groups_[idx];
VLOG(3) << "Fusion Producer idx " << idx << " Group -> "
<< producer->group_id;
// Skip to next iterator if producer is sub group.
if (producer->belong_groups.size()) continue;

// a prologue generate-shape-op group has only one op.
if (producer->ops.size() != 1) continue;

// only prologue generate-shape-op groups will be processed.
if (!producer->ops.at(0)->isa<cinn::dialect::GenerateShapeOp>()) continue;

// prologue generate-shape-op groups have no inputs.
if (!producer->input_ops.empty()) continue;

// do generate-shape-op fusion.
updated |= FusePrologueGenerateShapeOpGroupToConsumer(producer);
}
if (updated) {
UpdateFusionGroup();
}
}

bool FusePrologueGenerateShapeOpGroupToConsumer(const GroupPtr& producer) {
VLOG(3) << "FusePrologueGenerateShapeOpGroupToConsumer handling producer : "
<< producer->group_id;
// copy is need.
auto consumer_groups = producer->consumer_groups();
if (consumer_groups.size() > 0) {
RecomputeFuse(producer, consumer_groups);
return true;
}
return false;
}

bool DoGeneralHorizontalFusion() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/rewrite_generate_shape_ops_to_run_first_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.h"
#include "paddle/pir/dialect/shape/utils/shape_utils.h"

namespace cinn {
Expand Down Expand Up @@ -66,10 +66,10 @@ class GroupOpGenerateShapeOpsPattern : public pir::OpRewritePattern<cinn::dialec

};

class RewriteGenerateShapeOpsToRunFirstPass : public pir::PatternRewritePass {
class MoveGenerateShapeOpsToProloguePass : public pir::PatternRewritePass {
public:
RewriteGenerateShapeOpsToRunFirstPass()
: pir::PatternRewritePass("generate_shape_ops_to_run_first", 1) {}
MoveGenerateShapeOpsToProloguePass()
: pir::PatternRewritePass("move_generate_shape_ops_to_prologue", 1) {}

pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
Expand All @@ -80,7 +80,7 @@ class RewriteGenerateShapeOpsToRunFirstPass : public pir::PatternRewritePass {
bool CanApplyOn(pir::Operation* op) const override {
if (!(op->isa<pir::ModuleOp>() && op->num_regions() > 0)) return false;
auto* program = op->GetParentProgram();
VLOG(4) << "Before RewriteGenerateShapeOpsToRunFirstPass: " << *program;
VLOG(4) << "Before MoveGenerateShapeOpsToProloguePass: " << *program;
return true;
}

Expand All @@ -90,8 +90,8 @@ class RewriteGenerateShapeOpsToRunFirstPass : public pir::PatternRewritePass {

namespace ir {

std::unique_ptr<::pir::Pass> CreateRewriteGenerateShapeOpsToRunFirstPass() {
return std::make_unique<RewriteGenerateShapeOpsToRunFirstPass>();
std::unique_ptr<::pir::Pass> CreateMoveGenerateShapeOpsToProloguePass() {
return std::make_unique<MoveGenerateShapeOpsToProloguePass>();
}

} // namespace ir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace cinn {
namespace dialect {
namespace ir {

std::unique_ptr<::pir::Pass> CreateRewriteGenerateShapeOpsToRunFirstPass();
std::unique_ptr<::pir::Pass> CreateMoveGenerateShapeOpsToProloguePass();

} // namespace ir
} // namespace dialect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,7 @@ class OpFusionPassHelper {
// kNonFusible op can't fuse any other op.
auto producer_kind =
hlir::framework::pir::CompatibleInfo::OpKind(*producer);
if (producer_kind == OpPatternKind::kNonFusible &&
producer->name() != "cinn_op.generate_shape") {
if (producer_kind == OpPatternKind::kNonFusible) {
continue;
}
// VLOG(3) << "Producer Op: " << producer->id()
Expand Down Expand Up @@ -430,8 +429,7 @@ class OpFusionPassHelper {
consumer_fusion->input_ops.erase(producer);
consumer_fusion->op_pattern_kind =
static_cast<int>(consumer_fusion->op_pattern_kind) >
static_cast<int>(producer_kind) ||
producer->name() == "cinn_op.generate_shape"
static_cast<int>(producer_kind)
? consumer_fusion->op_pattern_kind
: producer_kind;

Expand Down Expand Up @@ -595,7 +593,6 @@ class OpFusionPassHelper {
}

bool CanFuse(::pir::Operation* producer, const ::pir::Operation* consumer) {
if (producer->name() == "cinn_op.generate_shape") return true;
auto& relation =
fusion_relation_map_[hlir::framework::pir::CompatibleInfo::OpKind(
*producer)];
Expand Down
6 changes: 2 additions & 4 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/rewrite_generate_shape_ops_to_run_first_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"
#include "paddle/cinn/hlir/framework/pir_compiler.h"
Expand Down Expand Up @@ -1605,7 +1605,6 @@ void AddCinnPass(std::shared_ptr<PassManager> &pass_manager, // NOLINT
std::make_unique<
cinn::dialect::ir::FuseShapeOpsIntoGenerateShapeOpPass>());
pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());
// pass_manager->AddPass(pir::CreateShapeOptimizationPass());
}
cinn::dialect::ir::PdOp2CinnOpConverter(&program);

Expand All @@ -1615,8 +1614,7 @@ void AddCinnPass(std::shared_ptr<PassManager> &pass_manager, // NOLINT
pass_manager->AddPass(pir::CreateBuildCinnPass());

if (has_dynamic_shape) {
// pass_manager->AddPass(cinn::dialect::ir::CreateRewriteGenerateShapeOpsToRunFirstPass());
// pass_manager->AddPass(pir::CreateShapeOptimizationPass());
pass_manager->AddPass(cinn::dialect::ir::CreateMoveGenerateShapeOpsToProloguePass());
}

pass_manager->AddPass(
Expand Down

0 comments on commit 7327ebe

Please sign in to comment.