Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN+PIR]Support DoGroupSchedule for PIRCompiler #58399

Merged
merged 6 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ std::unique_ptr<pir::Program> CINNGroupLoweringPass(::pir::Program* program) {
auto ir_compiler =
new cinn::hlir::framework::PIRCompiler(*program, target, scope);
auto group1 =
std::make_shared<cinn::hlir::framework::pir::Group>(group->nodes);
std::make_shared<cinn::hlir::framework::pir::Group>(group->ops);
auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group1});
compiler_list.push_back(ir_compiler);
std::unordered_map<std::string, ::pir::Attribute> op_attrs{
Expand All @@ -133,14 +133,14 @@ std::unique_ptr<pir::Program> CINNGroupLoweringPass(::pir::Program* program) {
};

// Generate jit kernel op input and output
auto vec_ins = GetBlockOutsideInput(group->nodes);
auto vec_ins = GetBlockOutsideInput(group->ops);

std::vector<pir::Value> vec_new_ins;
for (size_t i = 0; i < vec_ins.size(); ++i) {
vec_new_ins.push_back(value_map.at(vec_ins[i]));
}

auto vec_outs = GetBlockOutsideOutput(group->nodes);
auto vec_outs = GetBlockOutsideOutput(group->ops);

std::vector<pir::Type> vec_types;
for (auto& out : vec_outs) {
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,15 @@ static bool ReduceFuseReduce(const OpGroupPtr& first,
}
std::unique_ptr<cinn::dialect::ir::OpNode> reducer_0 = nullptr;
first.WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) {
if (!reducer_0 && op.kind() == kReduction) {
if (!reducer_0 && op.kind() == OpPatternKind::kReduction) {
reducer_0.reset(new cinn::dialect::ir::OpNode(op));
}
});
CHECK(reducer_0) << "Can't find reduce op in group " << first.group_id();

std::unique_ptr<cinn::dialect::ir::OpNode> reducer_1 = nullptr;
second.WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) {
if (!reducer_1 && op.kind() == kReduction) {
if (!reducer_1 && op.kind() == OpPatternKind::kReduction) {
reducer_1.reset(new cinn::dialect::ir::OpNode(op));
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ inline bool limit_args(const std::shared_ptr<ir::Group>& first,
const std::shared_ptr<ir::Group>& second) {
std::unordered_set<const ::pir::Operation*> args;
for (auto& group : {first, second}) {
for (auto node : group->input_nodes) {
args.insert(node.first);
for (auto iter : group->input_ops) {
args.insert(iter.first);
}
for (auto node : group->output_nodes) {
args.insert(node);
for (auto op : group->output_ops) {
args.insert(op);
}
}

Expand All @@ -66,8 +66,8 @@ inline bool is_same_shape(const std::shared_ptr<ir::Group>& first,
return false;
}

auto output_var_0 = GetValueShape((*first->master_nodes.begin())->result(0));
auto output_var_1 = GetValueShape((*second->master_nodes.begin())->result(0));
auto output_var_0 = GetValueShape((*first->master_ops.begin())->result(0));
auto output_var_1 = GetValueShape((*second->master_ops.begin())->result(0));
return output_var_0 == output_var_1;
}

Expand All @@ -77,8 +77,8 @@ inline bool is_same_size(const std::shared_ptr<ir::Group>& first,
return false;
}

auto output_var_0 = GetValueShape((*first->master_nodes.begin())->result(0));
auto output_var_1 = GetValueShape((*second->master_nodes.begin())->result(0));
auto output_var_0 = GetValueShape((*first->master_ops.begin())->result(0));
auto output_var_1 = GetValueShape((*second->master_ops.begin())->result(0));
if (output_var_0 == output_var_1) {
return true;
}
Expand All @@ -89,8 +89,8 @@ inline bool is_same_size(const std::shared_ptr<ir::Group>& first,
}

inline bool is_const_group(const std::shared_ptr<ir::Group>& group) {
return group->CollectNodes().size() == 1 &&
ConstantOps.count(group->CollectNodes()[0]->name());
return group->CollectOps().size() == 1 &&
ConstantOps.count(group->CollectOps()[0]->name());
}

inline bool elementwise_fuse_broadcast(
Expand All @@ -105,9 +105,9 @@ inline bool elementwise_fuse_broadcast(
return true;
}
// if first's output is not all in second's input
for (auto output : first->output_nodes) {
for (auto output : first->output_ops) {
return true;
if (!second->input_nodes.count(output)) {
if (!second->input_ops.count(output)) {
return false;
}

Expand All @@ -130,7 +130,7 @@ inline bool honrizontal_elementwise_fuse_reduce(
const std::shared_ptr<ir::Group>& first,
const std::shared_ptr<ir::Group>& second) {
std::shared_ptr<ir::Group> ele_group, reduce_group;
if (first->op_pattern_kind == kReduction) {
if (first->op_pattern_kind == OpPatternKind::kReduction) {
ele_group = second;
reduce_group = first;
} else {
Expand All @@ -143,10 +143,10 @@ inline bool honrizontal_elementwise_fuse_reduce(
}

auto ele_node_shape =
GetValueShape((*ele_group->master_nodes.begin())->result(0));
GetValueShape((*ele_group->master_ops.begin())->result(0));
int32_t size_ele = phi::product(ele_node_shape);
// TODO(phlrain): seems extrame danger herem, why compare multi Master Node?
for (auto* master : reduce_group->master_nodes) {
for (auto* master : reduce_group->master_ops) {
auto master_node_shape = GetValueShape(master->result(0));
int32_t size_master = phi::product(master_node_shape);
if (size_ele == size_master) {
Expand All @@ -169,9 +169,9 @@ inline bool elementwise_fuse_reduce(const std::shared_ptr<ir::Group>& first,

// if reduce nodes not in consumers of first group
std::queue<::pir::Operation*> candidates;
std::unordered_set<::pir::Operation*> first_node_set = first->NodeSet();
std::unordered_set<::pir::Operation*> second_node_set = second->NodeSet();
for (const auto& pair : second->input_nodes) {
std::unordered_set<::pir::Operation*> first_node_set = first->OpSet();
std::unordered_set<::pir::Operation*> second_node_set = second->OpSet();
for (const auto& pair : second->input_ops) {
if (first_node_set.find(pair.first) != first_node_set.end()) {
candidates.push(pair.first);
}
Expand All @@ -195,15 +195,15 @@ inline bool elementwise_fuse_reduce(const std::shared_ptr<ir::Group>& first,
visited.insert(consumer);
candidates.push(consumer);
}
if (second->master_nodes.count(consumer)) {
if (second->master_ops.count(consumer)) {
masters_in_consumers.insert(consumer);
}
}
}
if (!masters_in_consumers.empty()) {
bool flag = true;
auto first_node_shape =
GetValueShape((*first->master_nodes.begin())->result(0));
GetValueShape((*first->master_ops.begin())->result(0));
int32_t size_first = phi::product(first_node_shape);

for (::pir::Operation* master : masters_in_consumers) {
Expand All @@ -221,8 +221,8 @@ inline bool elementwise_fuse_reduce(const std::shared_ptr<ir::Group>& first,

// if reduce using block_reduce, can't fuse producer.
::pir::Operation* reducer = nullptr;
for (auto& node : second->master_nodes) {
if (GetOpKind(node->name()) == kReduction) {
for (auto& node : second->master_ops) {
if (GetOpKind(node->name()) == OpPatternKind::kReduction) {
reducer = node;
break;
}
Expand Down Expand Up @@ -289,8 +289,8 @@ inline bool broadcast_fuse_reduce(const std::shared_ptr<ir::Group>& first,
return true;
}
::pir::Operation* reducer = nullptr;
for (auto& node : second->master_nodes) {
if (GetOpKind(node->name()) == kReduction) {
for (auto& node : second->master_ops) {
if (GetOpKind(node->name()) == OpPatternKind::kReduction) {
reducer = node;
break;
}
Expand All @@ -300,7 +300,7 @@ inline bool broadcast_fuse_reduce(const std::shared_ptr<ir::Group>& first,
auto input_shape = GetValueShape(reducer->operand_source(0));
auto input_size = phi::product(input_shape);

auto output_shape = GetValueShape((*first->master_nodes.begin())->result(0));
auto output_shape = GetValueShape((*first->master_ops.begin())->result(0));
auto output_size = phi::product(output_shape);

if (input_size == output_size) {
Expand All @@ -325,10 +325,9 @@ inline bool horizontal_relation(const std::shared_ptr<ir::Group>& first,
const OpPatternKind op_pattern_kind) {
// merge injective
auto merge_nodes_set = [](const std::shared_ptr<ir::Group>& group) {
std::unordered_set<::pir::Operation*> nodes_set = group->nodes_set;
std::unordered_set<::pir::Operation*> nodes_set = group->ops_set;
for (auto& sub_group : group->fused_sub_groups) {
nodes_set.insert(sub_group->nodes_set.begin(),
sub_group->nodes_set.end());
nodes_set.insert(sub_group->ops_set.begin(), sub_group->ops_set.end());
}
return nodes_set;
};
Expand Down Expand Up @@ -398,14 +397,14 @@ inline bool horizontal_with_injective(
if (!is_same_size(first, second)) {
return false;
}
return horizontal_relation(first, second, kInjective);
return horizontal_relation(first, second, OpPatternKind::kInjective);
}

inline bool injective_horizontal_with_reduce(
const std::shared_ptr<ir::Group>& first,
const std::shared_ptr<ir::Group>& second) {
// check injective with injective.
if (!horizontal_relation(first, second, kInjective)) {
if (!horizontal_relation(first, second, OpPatternKind::kInjective)) {
return false;
}
return elementwise_fuse_reduce(first, second);
Expand All @@ -424,8 +423,8 @@ inline bool reduce_fuse_broadcast(const std::shared_ptr<ir::Group>& first,
// each reducer and its consumers with type of Broadcast needs to meet. It is
// required that each consumer of type Broadcast meet the same shape after
// broadcast as before reduce.
for (auto& node_in_master : first->master_nodes) {
if (GetOpKind(node_in_master->name()) != kReduction) {
for (auto& node_in_master : first->master_ops) {
if (GetOpKind(node_in_master->name()) != OpPatternKind::kReduction) {
continue;
}
::pir::Operation* reducer = node_in_master;
Expand Down Expand Up @@ -480,8 +479,8 @@ inline bool reduce_fuse_broadcast(const std::shared_ptr<ir::Group>& first,
visited_set.insert(consumer);
candidates.push(consumer);
}
if (GetOpKind(consumer->name()) == kBroadcast &&
second->NodeSet().find(consumer) != second->NodeSet().end()) {
if (GetOpKind(consumer->name()) == OpPatternKind::kBroadcast &&
second->OpSet().find(consumer) != second->OpSet().end()) {
broadcasters.insert(consumer);
}
}
Expand Down Expand Up @@ -543,17 +542,17 @@ inline bool reduce_fuse_reduce(const std::shared_ptr<ir::Group>& first,
return false;
}
::pir::Operation* reducer_0 = nullptr;
for (auto& reducer : first->master_nodes) {
if (GetOpKind(reducer->name()) == kReduction) {
for (auto& reducer : first->master_ops) {
if (GetOpKind(reducer->name()) == OpPatternKind::kReduction) {
reducer_0 = reducer;
break;
}
}
// CHECK(reducer_0) << "Can't find reduce op in group " << first->group_id;

::pir::Operation* reducer_1 = nullptr;
for (auto& reducer : second->master_nodes) {
if (GetOpKind(reducer->name()) == kReduction) {
for (auto& reducer : second->master_ops) {
if (GetOpKind(reducer->name()) == OpPatternKind::kReduction) {
reducer_1 = reducer;
break;
}
Expand Down Expand Up @@ -594,8 +593,8 @@ inline bool reduce_fuse_reduce(const std::shared_ptr<ir::Group>& first,
reducer_0_reduce_dim == reducer_1_reduce_dim) {
auto shared_size = 0;
for (auto& fusion_group : {first, second}) {
for (auto* master : fusion_group->master_nodes) {
if (GetOpKind(master->name()) == kReduction) {
for (auto* master : fusion_group->master_ops) {
if (GetOpKind(master->name()) == OpPatternKind::kReduction) {
shared_size += GetSharedSize(master);
}
}
Expand All @@ -615,8 +614,8 @@ inline bool reduce_fuse_reduce(const std::shared_ptr<ir::Group>& first,
reducer_0_reduce_dim == reducer_1_reduce_dim) {
auto shared_size = 0;
for (auto& fusion_group : {first, second}) {
for (auto* master : fusion_group->master_nodes) {
if (GetOpKind(master->name()) == kReduction) {
for (auto* master : fusion_group->master_ops) {
if (GetOpKind(master->name()) == OpPatternKind::kReduction) {
shared_size += GetSharedSize(master);
}
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/transforms/op_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class OpGroup {
// group.WalkOpNodes(get_reduce_op);
void WalkOpNodes(
const std::function<void(const OpNode&)>& VisitOpNode) const {
group_.lock()->WalkNodes(
group_.lock()->WalkOps(
[&](::pir::Operation* node) { VisitOpNode(OpNode(node)); });
}

Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/dialect/operator/transforms/op_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ class OpNode {

OpPatternKind kind() const {
auto kind = GetOpKind(node_->name());
if (kind == kBroadcast) {
if (kind == OpPatternKind::kBroadcast) {
// As binary op was defined as broadcast, actually it should be
// element-wise.
if (node_->name() != "broadcast_to") {
return kElementWise;
return OpPatternKind::kElementWise;
}
}
return kind;
Expand Down
Loading