[CINN+PIR] Support DoGroupSchedule for PIRCompiler (PaddlePaddle#58399)
* [CINN+PIR] Support DoGroupSchedule for PIRCompiler

  fix compilation problem

* fix conflict

* using output_ops to parse function arguments

* fix unittest

* remove VLOG(1)

* ignore some UT and add FIXME
Aurelius84 authored Nov 1, 2023
1 parent 17207d1 commit cf67c51
Showing 20 changed files with 2,515 additions and 574 deletions.
@@ -124,7 +124,7 @@ std::unique_ptr<pir::Program> CINNGroupLoweringPass(::pir::Program* program) {
   auto ir_compiler =
       new cinn::hlir::framework::PIRCompiler(*program, target, scope);
   auto group1 =
-      std::make_shared<cinn::hlir::framework::pir::Group>(group->nodes);
+      std::make_shared<cinn::hlir::framework::pir::Group>(group->ops);
   auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group1});
   compiler_list.push_back(ir_compiler);
   std::unordered_map<std::string, ::pir::Attribute> op_attrs{
@@ -133,14 +133,14 @@ std::unique_ptr<pir::Program> CINNGroupLoweringPass(::pir::Program* program) {
   };
 
   // Generate jit kernel op input and output
-  auto vec_ins = GetBlockOutsideInput(group->nodes);
+  auto vec_ins = GetBlockOutsideInput(group->ops);
 
   std::vector<pir::Value> vec_new_ins;
   for (size_t i = 0; i < vec_ins.size(); ++i) {
     vec_new_ins.push_back(value_map.at(vec_ins[i]));
   }
 
-  auto vec_outs = GetBlockOutsideOutput(group->nodes);
+  auto vec_outs = GetBlockOutsideOutput(group->ops);
 
   std::vector<pir::Type> vec_types;
   for (auto& out : vec_outs) {
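The recurring change across these files is a rename on the CINN group abstraction from graph-node vocabulary to operation vocabulary. Below is a reconstructed sketch listing only the members this commit's hunks touch — it is not the actual header, and the exact element and mapped types are assumptions (the diffs only show how the members are used):

```cpp
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace pir {
class Operation;
}

// Reconstructed from the hunks in this commit; types are assumed.
struct Group {
  std::vector<::pir::Operation*> ops;                    // was: nodes
  std::unordered_set<::pir::Operation*> ops_set;         // was: nodes_set
  std::unordered_map<::pir::Operation*, int> input_ops;  // was: input_nodes
  std::unordered_set<::pir::Operation*> output_ops;      // was: output_nodes
  std::unordered_set<::pir::Operation*> master_ops;      // was: master_nodes

  std::unordered_set<::pir::Operation*> OpSet() const;   // was: NodeSet()
  std::vector<::pir::Operation*> CollectOps() const;     // was: CollectNodes()
  void WalkOps(                                          // was: WalkNodes()
      const std::function<void(::pir::Operation*)>& visit) const;
};
```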

Large diffs are not rendered by default.

@@ -186,15 +186,15 @@ static bool ReduceFuseReduce(const OpGroupPtr& first,
   }
   std::unique_ptr<cinn::dialect::ir::OpNode> reducer_0 = nullptr;
   first.WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) {
-    if (!reducer_0 && op.kind() == kReduction) {
+    if (!reducer_0 && op.kind() == OpPatternKind::kReduction) {
       reducer_0.reset(new cinn::dialect::ir::OpNode(op));
     }
   });
   CHECK(reducer_0) << "Can't find reduce op in group " << first.group_id();
 
   std::unique_ptr<cinn::dialect::ir::OpNode> reducer_1 = nullptr;
   second.WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) {
-    if (!reducer_1 && op.kind() == kReduction) {
+    if (!reducer_1 && op.kind() == OpPatternKind::kReduction) {
       reducer_1.reset(new cinn::dialect::ir::OpNode(op));
     }
   });
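The second mechanical change, seen here and throughout the fusion helpers below, is qualifying OpPatternKind enumerators (kReduction becomes OpPatternKind::kReduction). The call sites are consistent with a scoped enumeration, though explicit qualification of an unscoped enum would read the same. A minimal standalone sketch under that assumption, listing only the enumerators this commit uses:

```cpp
#include <iostream>

// Assumed shape of the enum; only enumerators seen in this commit are listed.
enum class OpPatternKind { kElementWise, kBroadcast, kInjective, kReduction };

int main() {
  OpPatternKind kind = OpPatternKind::kReduction;
  // With a scoped enum, the pre-commit spelling `kind == kReduction`
  // no longer compiles, hence the qualifier added at every call site.
  std::cout << (kind == OpPatternKind::kReduction) << "\n";
  return 0;
}
```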
@@ -40,11 +40,11 @@ inline bool limit_args(const std::shared_ptr<ir::Group>& first,
                        const std::shared_ptr<ir::Group>& second) {
   std::unordered_set<const ::pir::Operation*> args;
   for (auto& group : {first, second}) {
-    for (auto node : group->input_nodes) {
-      args.insert(node.first);
+    for (auto iter : group->input_ops) {
+      args.insert(iter.first);
     }
-    for (auto node : group->output_nodes) {
-      args.insert(node);
+    for (auto op : group->output_ops) {
+      args.insert(op);
     }
   }
@@ -66,8 +66,8 @@ inline bool is_same_shape(const std::shared_ptr<ir::Group>& first,
     return false;
   }
 
-  auto output_var_0 = GetValueShape((*first->master_nodes.begin())->result(0));
-  auto output_var_1 = GetValueShape((*second->master_nodes.begin())->result(0));
+  auto output_var_0 = GetValueShape((*first->master_ops.begin())->result(0));
+  auto output_var_1 = GetValueShape((*second->master_ops.begin())->result(0));
   return output_var_0 == output_var_1;
 }
 
@@ -77,8 +77,8 @@ inline bool is_same_size(const std::shared_ptr<ir::Group>& first,
     return false;
   }
 
-  auto output_var_0 = GetValueShape((*first->master_nodes.begin())->result(0));
-  auto output_var_1 = GetValueShape((*second->master_nodes.begin())->result(0));
+  auto output_var_0 = GetValueShape((*first->master_ops.begin())->result(0));
+  auto output_var_1 = GetValueShape((*second->master_ops.begin())->result(0));
   if (output_var_0 == output_var_1) {
     return true;
   }
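is_same_shape demands identical shapes, while the folded remainder of is_same_size presumably relaxes that to a total-element-count comparison — the quantity phi::product computes in the hunks further down. A standalone illustration of that weaker notion of sameness (Product is a stand-in, not the phi API):

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Stand-in for phi::product: total element count of a shape.
int64_t Product(const std::vector<int64_t>& shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  // {8, 128} and {1024} differ as shapes but agree in element count,
  // so two groups with these output shapes still compare equal by size.
  return Product({8, 128}) == Product({1024}) ? 0 : 1;
}
```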
@@ -89,8 +89,8 @@ inline bool is_same_size(const std::shared_ptr<ir::Group>& first,
 }
 
 inline bool is_const_group(const std::shared_ptr<ir::Group>& group) {
-  return group->CollectNodes().size() == 1 &&
-         ConstantOps.count(group->CollectNodes()[0]->name());
+  return group->CollectOps().size() == 1 &&
+         ConstantOps.count(group->CollectOps()[0]->name());
 }
 
 inline bool elementwise_fuse_broadcast(
@@ -105,9 +105,9 @@ inline bool elementwise_fuse_broadcast(
     return true;
   }
   // if first's output is not all in second's input
-  for (auto output : first->output_nodes) {
+  for (auto output : first->output_ops) {
     return true;
-    if (!second->input_nodes.count(output)) {
+    if (!second->input_ops.count(output)) {
       return false;
     }
 
@@ -130,7 +130,7 @@ inline bool honrizontal_elementwise_fuse_reduce(
     const std::shared_ptr<ir::Group>& first,
     const std::shared_ptr<ir::Group>& second) {
   std::shared_ptr<ir::Group> ele_group, reduce_group;
-  if (first->op_pattern_kind == kReduction) {
+  if (first->op_pattern_kind == OpPatternKind::kReduction) {
     ele_group = second;
     reduce_group = first;
   } else {
@@ -143,10 +143,10 @@
   }
 
   auto ele_node_shape =
-      GetValueShape((*ele_group->master_nodes.begin())->result(0));
+      GetValueShape((*ele_group->master_ops.begin())->result(0));
   int32_t size_ele = phi::product(ele_node_shape);
   // TODO(phlrain): seems extrame danger herem, why compare multi Master Node?
-  for (auto* master : reduce_group->master_nodes) {
+  for (auto* master : reduce_group->master_ops) {
     auto master_node_shape = GetValueShape(master->result(0));
     int32_t size_master = phi::product(master_node_shape);
     if (size_ele == size_master) {
@@ -169,9 +169,9 @@ inline bool elementwise_fuse_reduce(const std::shared_ptr<ir::Group>& first,
 
   // if reduce nodes not in consumers of first group
   std::queue<::pir::Operation*> candidates;
-  std::unordered_set<::pir::Operation*> first_node_set = first->NodeSet();
-  std::unordered_set<::pir::Operation*> second_node_set = second->NodeSet();
-  for (const auto& pair : second->input_nodes) {
+  std::unordered_set<::pir::Operation*> first_node_set = first->OpSet();
+  std::unordered_set<::pir::Operation*> second_node_set = second->OpSet();
+  for (const auto& pair : second->input_ops) {
     if (first_node_set.find(pair.first) != first_node_set.end()) {
       candidates.push(pair.first);
     }
@@ -195,15 +195,15 @@ inline bool elementwise_fuse_reduce(const std::shared_ptr<ir::Group>& first,
         visited.insert(consumer);
         candidates.push(consumer);
       }
-      if (second->master_nodes.count(consumer)) {
+      if (second->master_ops.count(consumer)) {
         masters_in_consumers.insert(consumer);
       }
     }
   }
   if (!masters_in_consumers.empty()) {
     bool flag = true;
     auto first_node_shape =
-        GetValueShape((*first->master_nodes.begin())->result(0));
+        GetValueShape((*first->master_ops.begin())->result(0));
     int32_t size_first = phi::product(first_node_shape);
 
     for (::pir::Operation* master : masters_in_consumers) {
@@ -221,8 +221,8 @@ inline bool elementwise_fuse_reduce(const std::shared_ptr<ir::Group>& first,
 
   // if reduce using block_reduce, can't fuse producer.
   ::pir::Operation* reducer = nullptr;
-  for (auto& node : second->master_nodes) {
-    if (GetOpKind(node->name()) == kReduction) {
+  for (auto& node : second->master_ops) {
+    if (GetOpKind(node->name()) == OpPatternKind::kReduction) {
       reducer = node;
       break;
     }
@@ -289,8 +289,8 @@ inline bool broadcast_fuse_reduce(const std::shared_ptr<ir::Group>& first,
     return true;
   }
   ::pir::Operation* reducer = nullptr;
-  for (auto& node : second->master_nodes) {
-    if (GetOpKind(node->name()) == kReduction) {
+  for (auto& node : second->master_ops) {
+    if (GetOpKind(node->name()) == OpPatternKind::kReduction) {
       reducer = node;
       break;
     }
@@ -300,7 +300,7 @@
   auto input_shape = GetValueShape(reducer->operand_source(0));
   auto input_size = phi::product(input_shape);
 
-  auto output_shape = GetValueShape((*first->master_nodes.begin())->result(0));
+  auto output_shape = GetValueShape((*first->master_ops.begin())->result(0));
   auto output_size = phi::product(output_shape);
 
   if (input_size == output_size) {
@@ -325,10 +325,9 @@ inline bool horizontal_relation(const std::shared_ptr<ir::Group>& first,
                                 const OpPatternKind op_pattern_kind) {
   // merge injective
   auto merge_nodes_set = [](const std::shared_ptr<ir::Group>& group) {
-    std::unordered_set<::pir::Operation*> nodes_set = group->nodes_set;
+    std::unordered_set<::pir::Operation*> nodes_set = group->ops_set;
     for (auto& sub_group : group->fused_sub_groups) {
-      nodes_set.insert(sub_group->nodes_set.begin(),
-                       sub_group->nodes_set.end());
+      nodes_set.insert(sub_group->ops_set.begin(), sub_group->ops_set.end());
     }
     return nodes_set;
   };
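The lambda above flattens a group together with its fused sub-groups into one op set before the horizontal relation is checked. A standalone sketch of that flattening, using hypothetical minimal stand-ins for ::pir::Operation and ir::Group:

```cpp
#include <memory>
#include <unordered_set>
#include <vector>

// Hypothetical minimal stand-ins, only to show the flattening logic.
struct Operation {};
struct Group {
  std::unordered_set<Operation*> ops_set;
  std::vector<std::shared_ptr<Group>> fused_sub_groups;
};

// Mirrors merge_nodes_set above: the group's own ops plus every
// fused sub-group's ops, deduplicated by the set.
std::unordered_set<Operation*> MergeOpsSet(const std::shared_ptr<Group>& group) {
  std::unordered_set<Operation*> ops = group->ops_set;
  for (const auto& sub_group : group->fused_sub_groups) {
    ops.insert(sub_group->ops_set.begin(), sub_group->ops_set.end());
  }
  return ops;
}
```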
@@ -398,14 +397,14 @@ inline bool horizontal_with_injective(
   if (!is_same_size(first, second)) {
     return false;
   }
-  return horizontal_relation(first, second, kInjective);
+  return horizontal_relation(first, second, OpPatternKind::kInjective);
 }
 
 inline bool injective_horizontal_with_reduce(
     const std::shared_ptr<ir::Group>& first,
     const std::shared_ptr<ir::Group>& second) {
   // check injective with injective.
-  if (!horizontal_relation(first, second, kInjective)) {
+  if (!horizontal_relation(first, second, OpPatternKind::kInjective)) {
     return false;
   }
   return elementwise_fuse_reduce(first, second);
@@ -424,8 +423,8 @@ inline bool reduce_fuse_broadcast(const std::shared_ptr<ir::Group>& first,
   // each reducer and its consumers with type of Broadcast needs to meet. It is
   // required that each consumer of type Broadcast meet the same shape after
   // broadcast as before reduce.
-  for (auto& node_in_master : first->master_nodes) {
-    if (GetOpKind(node_in_master->name()) != kReduction) {
+  for (auto& node_in_master : first->master_ops) {
+    if (GetOpKind(node_in_master->name()) != OpPatternKind::kReduction) {
       continue;
     }
     ::pir::Operation* reducer = node_in_master;
@@ -480,8 +479,8 @@
           visited_set.insert(consumer);
           candidates.push(consumer);
         }
-        if (GetOpKind(consumer->name()) == kBroadcast &&
-            second->NodeSet().find(consumer) != second->NodeSet().end()) {
+        if (GetOpKind(consumer->name()) == OpPatternKind::kBroadcast &&
+            second->OpSet().find(consumer) != second->OpSet().end()) {
           broadcasters.insert(consumer);
         }
       }
@@ -543,17 +542,17 @@ inline bool reduce_fuse_reduce(const std::shared_ptr<ir::Group>& first,
     return false;
   }
   ::pir::Operation* reducer_0 = nullptr;
-  for (auto& reducer : first->master_nodes) {
-    if (GetOpKind(reducer->name()) == kReduction) {
+  for (auto& reducer : first->master_ops) {
+    if (GetOpKind(reducer->name()) == OpPatternKind::kReduction) {
       reducer_0 = reducer;
       break;
     }
   }
   // CHECK(reducer_0) << "Can't find reduce op in group " << first->group_id;
 
   ::pir::Operation* reducer_1 = nullptr;
-  for (auto& reducer : second->master_nodes) {
-    if (GetOpKind(reducer->name()) == kReduction) {
+  for (auto& reducer : second->master_ops) {
+    if (GetOpKind(reducer->name()) == OpPatternKind::kReduction) {
       reducer_1 = reducer;
       break;
     }
@@ -594,8 +593,8 @@ inline bool reduce_fuse_reduce(const std::shared_ptr<ir::Group>& first,
       reducer_0_reduce_dim == reducer_1_reduce_dim) {
     auto shared_size = 0;
     for (auto& fusion_group : {first, second}) {
-      for (auto* master : fusion_group->master_nodes) {
-        if (GetOpKind(master->name()) == kReduction) {
+      for (auto* master : fusion_group->master_ops) {
+        if (GetOpKind(master->name()) == OpPatternKind::kReduction) {
           shared_size += GetSharedSize(master);
         }
       }
@@ -615,8 +614,8 @@
       reducer_0_reduce_dim == reducer_1_reduce_dim) {
     auto shared_size = 0;
     for (auto& fusion_group : {first, second}) {
-      for (auto* master : fusion_group->master_nodes) {
-        if (GetOpKind(master->name()) == kReduction) {
+      for (auto* master : fusion_group->master_ops) {
+        if (GetOpKind(master->name()) == OpPatternKind::kReduction) {
           shared_size += GetSharedSize(master);
         }
       }
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/transforms/op_group.h
@@ -153,7 +153,7 @@ class OpGroup {
   //   group.WalkOpNodes(get_reduce_op);
   void WalkOpNodes(
       const std::function<void(const OpNode&)>& VisitOpNode) const {
-    group_.lock()->WalkNodes(
+    group_.lock()->WalkOps(
         [&](::pir::Operation* node) { VisitOpNode(OpNode(node)); });
   }
 
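The doc comment's suggested usage (group.WalkOpNodes(get_reduce_op)) still holds after the rename; the ReduceFuseReduce hunk earlier in this commit is the live caller. A short sketch of that pattern, assuming group is an OpGroup:

```cpp
// Find the first reduction op in a group through the OpNode adapter,
// mirroring the ReduceFuseReduce hunk in this commit.
std::unique_ptr<cinn::dialect::ir::OpNode> reducer = nullptr;
group.WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) {
  if (!reducer && op.kind() == OpPatternKind::kReduction) {
    reducer.reset(new cinn::dialect::ir::OpNode(op));
  }
});
```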
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/dialect/operator/transforms/op_node.h
@@ -31,11 +31,11 @@ class OpNode {
 
   OpPatternKind kind() const {
     auto kind = GetOpKind(node_->name());
-    if (kind == kBroadcast) {
+    if (kind == OpPatternKind::kBroadcast) {
       // As binary op was defined as broadcast, actually it should be
       // element-wise.
       if (node_->name() != "broadcast_to") {
-        return kElementWise;
+        return OpPatternKind::kElementWise;
       }
     }
     return kind;
