Skip to content

Commit

Permalink
[CINN]Unify compilation of broadcast group and other groups (#66768)
Browse files Browse the repository at this point in the history
* [CINN]Unify compilation of broadcast group and other groups

* fix
  • Loading branch information
Hongqing-work authored Jul 31, 2024
1 parent ce02a77 commit 2be8b7e
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 196 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,50 +13,17 @@
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.h"
#include "paddle/cinn/common/broadcast_tree.h"
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h"
#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"
#include "paddle/pir/include/pattern_rewrite/frozen_rewrite_pattern_set.h"

using OpLoweringGroup = cinn::hlir::framework::pir::OpLoweringGroup;
using OpLoweringGroupPtr = std::shared_ptr<OpLoweringGroup>;
using BroadcastCond = std::pair<symbol::Broadcastable<symbol::DimExpr>,
OpLoweringGroup::BranchType>;
using cinn::dialect::ir::details::CompileBroadcastGroupsAsOpAttribute;
using cinn::dialect::ir::details::GetBlockOutsideInput;

PD_DECLARE_bool(cinn_bc_branch_optimize);

namespace {
std::vector<pir::Value> GetOpOuputValues(const pir::Operation* op) {
std::vector<pir::Value> outputs;
outputs.reserve(op->num_results());
for (size_t i = 0; i < op->num_results(); ++i) {
outputs.push_back(op->result(i));
}
return outputs;
}

using ShapeOrDataDimExprs4ValueT =
std::function<const symbol::ShapeOrDataDimExprs&(pir::Value)>;

static bool SameInputOutputShape(
paddle::dialect::ExpandOp expand_op,
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) {
const auto& x = ShapeOrDataDimExprs4Value(expand_op.x());
const auto& shape = ShapeOrDataDimExprs4Value(expand_op.shape());
const auto& out = ShapeOrDataDimExprs4Value(expand_op.out());
if (x.data().has_value()) return false;
if (!shape.data().has_value()) return false;
if (out.data().has_value()) return false;
CHECK(shape.data().value() == out.shape());
return x.shape() == out.shape();
}

void UpdateGroupShapeExprs(
const OpLoweringGroupPtr& new_group,
const OpLoweringGroupPtr& origin_group,
Expand All @@ -83,8 +50,8 @@ void UpdateGroupShapeExprs(
}

OpLoweringGroupPtr CloneGroup(const OpLoweringGroupPtr& group,
const int& group_idx) {
return group->Clone(group_idx);
const std::string& name_suffix) {
return group->Clone(name_suffix);
}

void SetBroadcastLeafGroup(
Expand All @@ -93,7 +60,8 @@ void SetBroadcastLeafGroup(
const std::unordered_map<pir::Value, size_t>& value_to_dim_expr_idx,
std::vector<OpLoweringGroupPtr>* group_list,
const std::vector<BroadcastCond>& broadcast_conditions) {
auto new_group = CloneGroup(origin_group, group_list->size() + 1);
auto new_group =
CloneGroup(origin_group, std::to_string(group_list->size() + 1));
new_group->SetIsBroadcastLeaf(true);
new_group->SetBroadcastConditions(broadcast_conditions);
PADDLE_ENFORCE_EQ(
Expand Down Expand Up @@ -163,26 +131,11 @@ void ConstructBroadcastGroupList(
current_branch_conditions->pop_back();
}
}
} // namespace

namespace cinn::dialect::ir::details {

std::optional<std::shared_ptr<BroadcastTree>> ConstructBroadcastTree(
const cinn::common::BroadcastLeaf& leaves) {
VLOG(6) << "before constructed. broadcast-leaf: \n"
<< ToTxtString(cinn::common::BroadcastTree(leaves));
int num_of_leaves = 0;
auto broadcast_tree = std::make_shared<cinn::common::BroadcastTree>(
cinn::common::ConstructBroadcastTree(cinn::common::BroadcastLeaf(leaves),
&num_of_leaves));
if (num_of_leaves > FLAGS_pir_broadcast_tree_limit) {
LOG(WARNING) << "the number of leaf nodes in broadcast tree exceeds "
"limit.";
return std::nullopt;
}
VLOG(4) << "broadcast-tree: \n" << ToTxtString(*broadcast_tree);
return broadcast_tree;
}
struct GroupDimExprInfo {
cinn::common::BroadcastLeaf all_value_dim_exprs;
std::unordered_map<pir::Value, size_t> value_to_dim_expr_idx;
};

GroupDimExprInfo GetGroupDimExprInfo(const OpLoweringGroupPtr& group) {
std::unordered_set<pir::Value> value_view;
Expand All @@ -196,7 +149,7 @@ GroupDimExprInfo GetGroupDimExprInfo(const OpLoweringGroupPtr& group) {
});

GroupDimExprInfo group_dim_expr_info;
for (auto value : value_view) {
for (const auto& value : value_view) {
const auto& shape_dim_expr = group->GetShapeOrDataExprs(value);
const auto& data_shape = shape_dim_expr.data();
if (data_shape) {
Expand All @@ -211,50 +164,62 @@ GroupDimExprInfo GetGroupDimExprInfo(const OpLoweringGroupPtr& group) {
return group_dim_expr_info;
}

std::optional<std::shared_ptr<BroadcastTree>> GetBroadcastTreeForOptimize(
bool ContainBroadcastShape(const cinn::common::BroadcastLeaf& leaves) {
std::optional<symbol::Broadcastable<symbol::DimExpr>>
broadcastable_condition = cinn::common::GetFirstCstrBroadcastable(leaves);
return broadcastable_condition.has_value();
}

std::optional<std::shared_ptr<cinn::common::BroadcastTree>>
ConstructBroadcastTree(const cinn::common::BroadcastLeaf& leaves) {
VLOG(6) << "before constructed. broadcast-leaf: \n"
<< ToTxtString(cinn::common::BroadcastTree(leaves));
int num_of_leaves = 0;
auto broadcast_tree = std::make_shared<cinn::common::BroadcastTree>(
cinn::common::ConstructBroadcastTree(cinn::common::BroadcastLeaf(leaves),
&num_of_leaves));
if (num_of_leaves > FLAGS_pir_broadcast_tree_limit) {
LOG(WARNING) << "the number of leaf nodes in broadcast tree exceeds "
"limit.";
return std::nullopt;
}
VLOG(4) << "broadcast-tree: \n" << ToTxtString(*broadcast_tree);
return broadcast_tree;
}

} // namespace

namespace cinn::dialect::ir::details {
std::optional<std::vector<OpLoweringGroupPtr>> GetBroadcastGroupListForOptimize(
const OpLoweringGroupPtr& group) {
if (!FLAGS_cinn_bc_branch_optimize) return std::nullopt;

const common::BroadcastLeaf leaves = [&]() {
// NOTE(dev): Need UpdateShapeOrDataExprs firstly and the logic
// will be migated into BucketLower later.
UpdateGroupShapeOrDataExprs(const_cast<OpLoweringGroupPtr&>(group));
GroupDimExprInfo group_dim_expr_info = GetGroupDimExprInfo(group);
return group_dim_expr_info.all_value_dim_exprs;
}();
UpdateGroupShapeOrDataExprs(const_cast<OpLoweringGroupPtr&>(group));
GroupDimExprInfo group_dim_expr_info = GetGroupDimExprInfo(group);
if (!ContainBroadcastShape(group_dim_expr_info.all_value_dim_exprs))
return std::nullopt;

if (!ContainBroadcastShape(leaves)) return std::nullopt;
const auto& optional_broadcast_tree =
ConstructBroadcastTree(group_dim_expr_info.all_value_dim_exprs);

return ConstructBroadcastTree(leaves);
}
if (!optional_broadcast_tree.has_value()) return std::nullopt;

bool ContainBroadcastShape(const cinn::common::BroadcastLeaf& leaves) {
std::optional<symbol::Broadcastable<symbol::DimExpr>>
broadcastable_condition = cinn::common::GetFirstCstrBroadcastable(leaves);
return broadcastable_condition.has_value();
}
const auto& broadcast_tree = optional_broadcast_tree.value();

std::unordered_map<std::string, pir::Attribute> CompileBroadcastTree(
const OpLoweringGroupPtr& group,
const BroadcastTree& broadcast_tree,
const std::unordered_map<pir::Value, size_t>& value_to_dim_expr_idx) {
auto ShapeOrDataDimExprs4Value =
[&group](pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
return group->GetShapeOrDataExprs(value);
const auto& ChangeBroadcastTreeToGroupList =
[&]() -> std::vector<OpLoweringGroupPtr> {
std::vector<OpLoweringGroupPtr> group_list;
std::vector<BroadcastCond> current_branch_conditions;
const auto& value_to_dim_expr_idx =
group_dim_expr_info.value_to_dim_expr_idx;
ConstructBroadcastGroupList(*broadcast_tree,
group,
&current_branch_conditions,
value_to_dim_expr_idx,
&group_list);
return group_list;
};
// 1. broadcast tree to condition op
VLOG(4) << "broadcast tree to condition op";
std::vector<OpLoweringGroupPtr> group_list;
std::vector<BroadcastCond> current_branch_conditions;
ConstructBroadcastGroupList(broadcast_tree,
group,
&current_branch_conditions,
value_to_dim_expr_idx,
&group_list);

// 2. compile condition block to jit_kernel_op
auto op_attr = CompileBroadcastGroupsAsOpAttribute(group_list, group);

return op_attr;
return ChangeBroadcastTreeToGroupList();
}
} // namespace cinn::dialect::ir::details
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,12 @@
// limitations under the License.

#pragma once
#include "paddle/cinn/common/broadcast_tree.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h"
#include "paddle/pir/include/pattern_rewrite/pattern_match.h"
#include <optional>
#include "paddle/cinn/hlir/framework/pir/op_lowering_group.h"

using OpLoweringGroup = cinn::hlir::framework::pir::OpLoweringGroup;
using OpLoweringGroupPtr = std::shared_ptr<OpLoweringGroup>;
namespace cinn::dialect::ir::details {
using cinn::common::BroadcastTree;

class BroadcastTreeInfo;

struct GroupDimExprInfo {
common::BroadcastLeaf all_value_dim_exprs;
std::unordered_map<pir::Value, size_t> value_to_dim_expr_idx;
};

std::optional<std::shared_ptr<BroadcastTree>> ConstructBroadcastTree(
const common::BroadcastLeaf& leaves);

std::optional<std::shared_ptr<BroadcastTree>> GetBroadcastTreeForOptimize(
std::optional<std::vector<OpLoweringGroupPtr>> GetBroadcastGroupListForOptimize(
const OpLoweringGroupPtr& group);
bool ContainBroadcastShape(const common::BroadcastLeaf& leaves);
GroupDimExprInfo GetGroupDimExprInfo(const OpLoweringGroupPtr& group);

std::unordered_map<std::string, pir::Attribute> CompileBroadcastTree(
const OpLoweringGroupPtr& group,
const BroadcastTree& broadcast_tree,
const std::unordered_map<pir::Value, size_t>& value_to_dim_expr_idx);
} // namespace cinn::dialect::ir::details
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/pre_analysis.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/refresh_combine_pattern.h"
Expand All @@ -29,34 +28,6 @@
#include "paddle/pir/include/pass/pass_registry.h"

namespace cinn::dialect::ir::details {

pir::Operation* ProcessDyShapeGroup(const OpLoweringGroupPtr& group,
pir::PatternRewriter& rewriter) { // NOLINT
std::unordered_map<std::string, ::pir::Attribute> jit_kernel_attr = [&]() {
const auto& optional_broadcast_tree = GetBroadcastTreeForOptimize(group);
if (optional_broadcast_tree.has_value()) {
const std::shared_ptr<BroadcastTree> broadcast_tree =
optional_broadcast_tree.value();
const auto& value_to_dim_expr_idx =
GetGroupDimExprInfo(group).value_to_dim_expr_idx;
return CompileBroadcastTree(
group, *broadcast_tree, value_to_dim_expr_idx);
} else {
return GetJitKernelAttr(group);
}
}();

// compile group to jit_kernel_op
const auto& group_inputs = GetBlockOutsideInput(group->ops());
std::vector<pir::Type> output_types;
for (const auto& value : group->output_values()) {
output_types.push_back(value.type());
}
auto jit_kernel_op = rewriter.Build<cinn::dialect::JitKernelOp>(
group_inputs, jit_kernel_attr, output_types);
return jit_kernel_op;
}

class FusionOpPattern : public pir::OpRewritePattern<cinn::dialect::FusionOp> {
public:
FusionOpPattern(::pir::IrContext* context, const GroupInfoMap& group_infos)
Expand Down Expand Up @@ -136,18 +107,6 @@ class LowerCinnFusionOpPass : public pir::PatternRewritePass {
mutable GroupInfoMap group_infos_;
};

class DyShapeFusionOpPattern : public FusionOpPattern {
public:
using FusionOpPattern::FusionOpPattern;

protected:
virtual pir::Operation* ProcessGroup(
const OpLoweringGroupPtr& group,
pir::PatternRewriter& rewriter) const { // NOLINT
return ProcessDyShapeGroup(group, rewriter);
}
};

class LowerCinnDyShapeFusionOpPass : public pir::PatternRewritePass {
public:
LowerCinnDyShapeFusionOpPass()
Expand All @@ -158,7 +117,7 @@ class LowerCinnDyShapeFusionOpPass : public pir::PatternRewritePass {
context->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();

pir::RewritePatternSet ps(context);
ps.Add<DyShapeFusionOpPattern>(context, group_infos_);
ps.Add<FusionOpPattern>(context, group_infos_);
ps.Add<RefreshCombineOpPattern>(context);

return ps;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/pre_analysis.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h"
#include "paddle/cinn/hlir/framework/pir_compiler.h"
#include "paddle/common/flags.h"
Expand Down Expand Up @@ -51,9 +52,19 @@ void FusionOpAnalysis::PreCompileGroup() {

std::vector<OpLoweringGroupPtr> groups;
for (auto& group_info : *group_infos_) {
if (is_dy_shape_ &&
GetBroadcastTreeForOptimize(group_info.second).has_value())
continue;
if (is_dy_shape_) {
const auto& optional_broadcast_group_list =
GetBroadcastGroupListForOptimize(group_info.second);
if (optional_broadcast_group_list.has_value()) {
std::vector<OpLoweringGroupPtr> group_list =
optional_broadcast_group_list.value();
VLOG(4) << "Pre-Compile for Broadcast Group with size: "
<< group_list.size();
PirCompiler pir_compiler(cinn::common::DefaultDeviceTarget());
pir_compiler.BuildBroadcastTree(group_list, group_info.second);
continue;
}
}
groups.push_back(group_info.second);
}
// Build and trigger compilaion cache.
Expand Down
Loading

0 comments on commit 2be8b7e

Please sign in to comment.