[PIR+CINN]Fix parallel compilation value symbolic update (#63305)
* [PIR+CINN]Fix parallel compilation value symbolic update

* fix typo

* del useless file

* fix comment

* fix typo
Aurelius84 authored Apr 9, 2024
1 parent 6adee41 commit 7c5e803
Showing 6 changed files with 76 additions and 56 deletions.
@@ -34,6 +34,9 @@ pir::Operation* ProcessDyShapeGroup(
     const OpLoweringGroupPtr& group,
     pir::ShapeConstraintIRAnalysis& shape_analysis,  // NOLINT
     pir::PatternRewriter& rewriter) {  // NOLINT
+  // NOTE(dev): Need UpdateShapeOrDataExprs first; the logic
+  // will be migrated into BucketLower later.
+  UpdateGroupShapeOrDataExprs(const_cast<OpLoweringGroupPtr&>(group));
   auto group_inputs = GetBlockOutsideInput(group->ops());
   GroupDimExprInfo group_dim_expr_info = GetGroupDimExprInfo(group);
   const auto& leaves = group_dim_expr_info.all_value_dim_exprs;
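Why the ordering matters: under parallel compilation the program-wide shape analysis keeps being enriched while groups wait their turn, so the value-to-symbolic-shape map captured when a group was built can be stale by the time ProcessDyShapeGroup lowers it; refreshing the map immediately before lowering closes that window. A minimal sketch of the required ordering, using simplified stand-ins rather than Paddle's real types:

  // Hedged sketch, not Paddle's API: Group, BucketLower and g_shape_analysis
  // are stand-ins that only illustrate the call order enforced above.
  #include <map>
  #include <string>
  #include <vector>

  struct Group {
    std::vector<int> values;                 // stand-in for pir::Value list
    std::map<int, std::string> shape_exprs;  // value -> symbolic dim expr
  };

  // Stand-in for the program-wide ShapeConstraintIRAnalysis, which other
  // compilation threads may have updated since the group was built.
  std::map<int, std::string> g_shape_analysis = {{0, "S0"}, {1, "S0 * 2"}};

  void UpdateGroupShapeOrDataExprs(Group& group) {
    for (int v : group.values) group.shape_exprs[v] = g_shape_analysis.at(v);
  }

  void BucketLower(const Group& group) {
    // Lowering consumes the per-group map; a stale or missing entry here
    // would yield a wrongly-shaped kernel, hence the update must run first.
    for (int v : group.values) (void)group.shape_exprs.at(v);
  }

  int main() {
    Group g{{0, 1}, {}};
    UpdateGroupShapeOrDataExprs(g);  // what the added line does, in spirit
    BucketLower(g);
  }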
49 changes: 28 additions & 21 deletions paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.cc
@@ -88,33 +88,36 @@ std::unordered_map<std::string, ::pir::Attribute> GetJitKernelAttr(
 
 OpLoweringGroupPtr BuildOpLoweringGroup(pir::Operation* fusion_op_ptr) {
   auto fusion_op = fusion_op_ptr->dyn_cast<cinn::dialect::FusionOp>();
-  auto group = std::make_shared<OpLoweringGroup>();
-  group->set_op_pattern_kind(
-      cinn::hlir::framework::OpPatternKind::kElementWise);
+  std::vector<::pir::Operation*> ops;
+  auto group_op_kind = cinn::hlir::framework::OpPatternKind::kElementWise;
+  // Rebuild ops of the group
+  for (auto op : fusion_op.GetOperators()) {
+    if (!op->isa<::pir::YieldOp>()) {
+      ops.push_back(op);
+      group_op_kind = static_cast<int>(CompatibleInfo::OpKind(*op)) >
+                              static_cast<int>(group_op_kind)
+                          ? CompatibleInfo::OpKind(*op)
+                          : group_op_kind;
+    }
+  }
+
+  auto group = std::make_shared<OpLoweringGroup>(ops);
 
   if (fusion_op.attributes().count("group_info")) {
     auto attr = fusion_op.attribute("group_info")
                     .dyn_cast<cinn::dialect::GroupInfoAttribute>()
                     .data();
 
-    group->set_op_pattern_kind(attr.op_pattern_kind);
+    group_op_kind =
+        static_cast<int>(attr.op_pattern_kind) > static_cast<int>(group_op_kind)
+            ? attr.op_pattern_kind
+            : group_op_kind;
     group->set_loop_ranges(attr.loop_ranges);
     group->set_loop_ranges_expr(attr.loop_ranges_expr);
-
     group->set_reduce_axis(attr.reduce_axis);
     group->set_alignment_schedule_info(attr.alignment_schedule_info);
   }
-
-  // Rebuild ops of the group
-  for (auto op : fusion_op.GetOperators()) {
-    if (!op->isa<::pir::YieldOp>()) {
-      group->mut_ops().push_back(op);
-      auto op_pattern_kind = static_cast<int>(CompatibleInfo::OpKind(*op)) >
-                                     static_cast<int>(group->op_pattern_kind())
-                                 ? CompatibleInfo::OpKind(*op)
-                                 : group->op_pattern_kind();
-      group->set_op_pattern_kind(op_pattern_kind);
-    }
-  }
+  group->set_op_pattern_kind(group_op_kind);
+
   // Rebuild output_ops and input_ops of the group
   auto yield_op = fusion_op.GetOperators().back();
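The rewrite above gathers the ops and folds their pattern kind before the group exists, so the OpLoweringGroup constructor receives a complete op list (which it now needs in order to derive fn_name_; see the op_lowering_group.h hunks further down). The ternaries implement a max-fold over an ordered kind enum; a self-contained illustration, with illustrative enum values rather than CINN's real ones:

  // Hedged sketch: fold an op-pattern kind over a group by taking the maximum,
  // mirroring the group_op_kind updates above. Enum values are illustrative.
  #include <algorithm>
  #include <cstdio>
  #include <vector>

  enum class OpPatternKind { kElementWise = 0, kBroadcast = 1, kReduction = 2 };

  OpPatternKind FoldKind(const std::vector<OpPatternKind>& op_kinds) {
    OpPatternKind folded = OpPatternKind::kElementWise;
    for (OpPatternKind k : op_kinds)
      folded = std::max(folded, k);  // enum class compares by underlying value
    return folded;
  }

  int main() {
    std::vector<OpPatternKind> kinds = {OpPatternKind::kElementWise,
                                        OpPatternKind::kReduction,
                                        OpPatternKind::kBroadcast};
    std::printf("%d\n", static_cast<int>(FoldKind(kinds)));  // prints 2
  }

Folding into a local and calling set_op_pattern_kind once also avoids the old per-op getter/setter round trip through the group.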
@@ -127,10 +130,7 @@ OpLoweringGroupPtr BuildOpLoweringGroup(pir::Operation* fusion_op_ptr) {
   // Because the group is rebuilt, the order of group.output_values generated
   // by BuildCUDAJITInfo may not be the same as the order bound in the yield
   // op, so a mapping is required.
-  auto& shape_analysis =
-      pir::ShapeAnalysisManager::Instance().Get(fusion_op->GetParentProgram());
-  group->set_value_to_shape_or_data_exprs(
-      CreateGroupShapeOrDataExprs(group, shape_analysis));
+  UpdateGroupShapeOrDataExprs(group);
   if (FLAGS_cinn_enable_map_expr) {
     cinn::adt::TryGenerateMapExprFromGroup(group);
   }
@@ -139,4 +139,11 @@ OpLoweringGroupPtr BuildOpLoweringGroup(pir::Operation* fusion_op_ptr) {
   return group;
 }
 
+void UpdateGroupShapeOrDataExprs(OpLoweringGroupPtr group) {
+  auto& shape_analysis =
+      pir::ShapeAnalysisManager::Instance().Get(group->GetParentProgram());
+  group->set_value_to_shape_or_data_exprs(
+      CreateGroupShapeOrDataExprs(group, shape_analysis));
+}
+
 }  // namespace cinn::dialect::ir::details
@@ -31,4 +31,6 @@ std::unordered_map<std::string, ::pir::Attribute> GetJitKernelAttr(
 
 OpLoweringGroupPtr BuildOpLoweringGroup(pir::Operation* fusion_op_ptr);
 
+void UpdateGroupShapeOrDataExprs(OpLoweringGroupPtr group);
+
 }  // namespace cinn::dialect::ir::details
11 changes: 11 additions & 0 deletions paddle/cinn/hlir/framework/pir/op_lowering_group.cc
@@ -19,6 +19,17 @@ namespace hlir {
 namespace framework {
 namespace pir {
 
+::pir::Program* OpLoweringGroup::GetParentProgram() const {
+  PADDLE_ENFORCE_GT(ops_.size(),
+                    0,
+                    ::common::errors::PreconditionNotMet(
+                        "Require at least one op in the group."));
+  PADDLE_ENFORCE_NOT_NULL(
+      ops_[0],
+      ::common::errors::Unavailable("Found group.ops_[0] is nullptr."));
+  return ops_[0]->GetParentProgram();
+}
+
 std::shared_ptr<OpLoweringGroup> OpLoweringGroup::Clone(
     ::pir::Block* target_block, ::pir::IrMapping* ir_mapping) const {
   std::vector<::pir::Operation*> new_ops;
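The new GetParentProgram resolves the owning ::pir::Program through the group's first op, failing loudly when the group is empty or holds a null op. A stripped-down sketch of the same precondition-checked accessor pattern; every type below is a stand-in:

  // Hedged sketch: validate preconditions, then delegate to the first op.
  #include <stdexcept>
  #include <vector>

  struct Program {};
  struct Operation {
    Program* parent = nullptr;
    Program* GetParentProgram() const { return parent; }
  };

  struct OpLoweringGroup {
    std::vector<Operation*> ops_;
    Program* GetParentProgram() const {
      // Stand-ins for PADDLE_ENFORCE_GT / PADDLE_ENFORCE_NOT_NULL above.
      if (ops_.empty())
        throw std::runtime_error("Require at least one op in the group.");
      if (ops_[0] == nullptr)
        throw std::runtime_error("Found group.ops_[0] is nullptr.");
      return ops_[0]->GetParentProgram();
    }
  };

  int main() {
    Program prog;
    Operation op{&prog};
    OpLoweringGroup group{{&op}};
    return group.GetParentProgram() == &prog ? 0 : 1;
  }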
59 changes: 28 additions & 31 deletions paddle/cinn/hlir/framework/pir/op_lowering_group.h
@@ -22,6 +22,7 @@
 #include "paddle/cinn/common/context.h"
 #include "paddle/cinn/hlir/framework/op.h"
 #include "paddle/cinn/hlir/framework/pir/utils.h"
+#include "paddle/common/enforce.h"
 #include "paddle/pir/include/core/builtin_type_interfaces.h"
 #include "paddle/pir/include/core/operation.h"
 #include "paddle/pir/include/core/value.h"
@@ -38,15 +39,18 @@ namespace framework {
 namespace pir {
 class OpLoweringGroup {
  public:
-  OpLoweringGroup() = default;
   OpLoweringGroup(const OpLoweringGroup&) = delete;
   OpLoweringGroup(OpLoweringGroup&&) = delete;
 
   explicit OpLoweringGroup(const std::vector<::pir::Operation*>& group_ops)
-      : ops_(group_ops) {}
+      : ops_(group_ops) {
+    fn_name_ = CompatibleInfo::GroupOpsName(ops_);
+  }
 
   explicit OpLoweringGroup(std::initializer_list<::pir::Operation*> group_ops)
-      : ops_(group_ops) {}
+      : ops_(group_ops) {
+    fn_name_ = CompatibleInfo::GroupOpsName(ops_);
+  }
 
   struct SharedGroupHasher {
     size_t operator()(
Expand Down Expand Up @@ -88,27 +92,18 @@ class OpLoweringGroup {

std::unordered_set<::pir::Value> GetInputOpValues() const {
std::unordered_set<::pir::Value> group_inputs;

std::unordered_set<::pir::Operation*> ops_set;
for (auto op : this->ops_) {
ops_set.insert(op);
}
std::unordered_set<::pir::Operation*> ops_set(this->ops_.begin(),
this->ops_.end());

// count all op's input Value
for (auto op : this->ops_) {
for (auto op : ops_set) {
for (auto& value : op->operands_source()) {
if (!value || !value.type()) {
continue;
}

if (!ops_set.count(value.defining_op())) {
// if the input value owner op is not in OpSet, it's the group's input
group_inputs.insert(value);
if (!value || !value.type() || ops_set.count(value.defining_op()))
continue;
}
// if the input value owner op is not in OpSet, it's the group's input
group_inputs.insert(value);
}
}

return group_inputs;
}
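The GetInputOpValues cleanup above is behavior-preserving: the membership set is built in one shot from iterators, and an operand is recorded as a group input exactly when it is non-null, typed, and produced outside the set. Iterating ops_set rather than this->ops_ changes only visit order, not the resulting set. A toy model of the same logic, with ops and values reduced to plain structs:

  // Hedged sketch of the simplified group-input computation above.
  #include <cstdio>
  #include <set>
  #include <vector>

  struct Op {
    int id;
    std::vector<const Op*> operands;  // producing op of each input value
  };

  std::set<int> GetGroupInputs(const std::vector<const Op*>& ops) {
    std::set<const Op*> ops_set(ops.begin(), ops.end());  // one-shot build
    std::set<int> group_inputs;
    for (const Op* op : ops_set)
      for (const Op* producer : op->operands)
        // a value is a group input iff its producer is outside the set
        if (producer && !ops_set.count(producer))
          group_inputs.insert(producer->id);
    return group_inputs;
  }

  int main() {
    Op outside{0, {}};
    Op a{1, {&outside}};
    Op b{2, {&a}};
    for (int id : GetGroupInputs({&a, &b}))
      std::printf("input from op %d\n", id);  // prints: input from op 0
  }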

@@ -127,19 +122,14 @@
     return group_outputs;
   }
 
-  std::string FuncName() const {
-    if (fn_name_ == "") {
-      // TODO(Aurelius84): Polish this implementation.
-      const_cast<OpLoweringGroup*>(this)->fn_name_ =
-          CompatibleInfo::GroupOpsName(ops_);
-    }
-    return this->fn_name_;
-  }
+  const std::string& FuncName() const { return fn_name_; }
 
   const symbol::ShapeOrDataDimExprs& GetShapeOrDataExprs(
       const ::pir::Value& value) const {
-    CHECK(value_to_shape_or_data_exprs_.count(value))
-        << "value not found in value_to_shape_or_data_exprs_";
+    PADDLE_ENFORCE_EQ(HasShapeOrDataExprs(value),
+                      true,
+                      ::common::errors::NotFound(
+                          "value not found in value_to_shape_or_data_exprs_"));
     return value_to_shape_or_data_exprs_.at(value);
   }
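FuncName can shed its const_cast caching because both constructors now compute fn_name_ eagerly (see the constructor hunk above); the accessor becomes a genuine const member returning a const reference, which is also safer when several compilation threads read the same group. The two styles side by side, with MakeName standing in for CompatibleInfo::GroupOpsName:

  // Hedged sketch contrasting lazy const_cast caching with eager init.
  #include <string>

  std::string MakeName() { return "fn_..."; }  // stand-in name builder

  class Lazy {  // old style: a "const" method that mutates through const_cast
   public:
    std::string FuncName() const {
      if (fn_name_.empty()) const_cast<Lazy*>(this)->fn_name_ = MakeName();
      return fn_name_;
    }
   private:
    std::string fn_name_;
  };

  class Eager {  // new style: compute once at construction, return const ref
   public:
    Eager() : fn_name_(MakeName()) {}
    const std::string& FuncName() const { return fn_name_; }
   private:
    std::string fn_name_;
  };

  int main() {
    Lazy l;
    Eager e;
    return l.FuncName() == e.FuncName() ? 0 : 1;  // both yield "fn_..."
  }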

@@ -198,14 +188,21 @@
   }
 
   std::shared_ptr<adt::MapExprCtx> mut_map_expr_ctx() {
-    CHECK_NOTNULL(map_expr_ctx_);
+    PADDLE_ENFORCE_NOT_NULL(
+        map_expr_ctx_,
+        ::common::errors::Unavailable("Required map_expr_ctx_ != nullptr."));
     return map_expr_ctx_;
   }
 
   const adt::MapExprCtx& map_expr_ctx() const {
-    return *CHECK_NOTNULL(map_expr_ctx_);
+    PADDLE_ENFORCE_NOT_NULL(
+        map_expr_ctx_,
+        ::common::errors::Unavailable("Required map_expr_ctx_ != nullptr."));
+    return *map_expr_ctx_;
   }
 
+  ::pir::Program* GetParentProgram() const;
+
   void set_value_to_shape_or_data_exprs(
       const std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs>&
           value_to_shape_or_data_exprs) {
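The CHECK_NOTNULL-to-PADDLE_ENFORCE_NOT_NULL moves in this hunk swap an abort-on-failure glog check for a typed error that callers can catch and report. A minimal illustration of the difference; the ENFORCE_NOT_NULL macro below is a hypothetical stand-in, not Paddle's definition:

  // Hedged sketch: enforce-style checks throw instead of aborting the process.
  #include <cstdio>
  #include <memory>
  #include <stdexcept>

  #define ENFORCE_NOT_NULL(ptr, msg)                        \
    do {                                                    \
      if ((ptr) == nullptr) throw std::runtime_error(msg);  \
    } while (0)

  struct MapExprCtx {};

  const MapExprCtx& map_expr_ctx(const std::shared_ptr<MapExprCtx>& ctx) {
    ENFORCE_NOT_NULL(ctx.get(), "Required map_expr_ctx_ != nullptr.");
    return *ctx;
  }

  int main() {
    try {
      map_expr_ctx(nullptr);
    } catch (const std::runtime_error& e) {
      std::printf("caught: %s\n", e.what());  // recoverable, unlike CHECK
    }
  }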
@@ -285,6 +282,7 @@
   std::string group_id_{common::UniqName("group_")};
   // op in this group
   std::vector<::pir::Operation*> ops_;
+  std::string fn_name_;
   // output ops of the group.
   std::unordered_set<::pir::Operation*> output_ops_;
   // op pattern kind.
@@ -293,7 +291,6 @@
   std::vector<std::string> input_names_;
   std::vector<std::string> output_names_;
   std::vector<::pir::Value> output_values_;
-  std::string fn_name_{""};
   std::map<int, CINNKernelInfo::ArgDimIdx> int_args_map_;
 
   alignment_schedule_info_t alignment_schedule_info_;
8 changes: 4 additions & 4 deletions paddle/cinn/hlir/framework/pir/utils.cc
@@ -419,12 +419,12 @@ std::string CompatibleInfo::OpFuncName(const ::pir::Operation& op) {
 
 std::string CompatibleInfo::GroupOpsName(
     const std::vector<::pir::Operation*>& ops) {
-  std::string name = "fn";
+  std::string name = "fn_";
   for (auto* op : ops) {
-    std::string op_name = OpName(*op);
-    name += "_" + cinn::common::Context::Global().NewName(op_name);
+    name += OpName(*op);
+    name += "_";
   }
-  return name;
+  return cinn::common::Context::Global().NewName(name);
 }
 
 std::string CompatibleInfo::ValueName(const ::pir::Value& value) {
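The GroupOpsName change moves the uniquifying counter from per-op to per-group: the old loop appended a freshly uniquified token for every op, while the new code concatenates the raw op names and uniquifies the whole signature once. A toy NewName makes the difference visible; the real counter lives in cinn::common::Context:

  // Hedged sketch: NewName is a toy stand-in for
  // cinn::common::Context::Global().NewName.
  #include <cstdio>
  #include <map>
  #include <string>
  #include <vector>

  std::string NewName(const std::string& prefix) {
    static std::map<std::string, int> counters;
    return prefix + std::to_string(counters[prefix]++);
  }

  std::string OldGroupOpsName(const std::vector<std::string>& ops) {
    std::string name = "fn";
    for (const auto& op : ops) name += "_" + NewName(op);  // counter per op
    return name;  // e.g. "fn_add0_relu0"
  }

  std::string NewGroupOpsName(const std::vector<std::string>& ops) {
    std::string name = "fn_";
    for (const auto& op : ops) {
      name += op;
      name += "_";
    }
    return NewName(name);  // one counter per signature, e.g. "fn_add_relu_0"
  }

  int main() {
    std::vector<std::string> ops = {"add", "relu"};
    std::printf("%s\n", OldGroupOpsName(ops).c_str());
    std::printf("%s\n", NewGroupOpsName(ops).c_str());
  }

Under parallel compilation this also means a kernel name embeds at most one global counter value instead of one per op.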
