Skip to content

Commit

Permalink
Cinn trivalop fuse (PaddlePaddle#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 authored Mar 12, 2024
1 parent 2a6a72a commit efe91cc
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions paddle/cinn/hlir/framework/pir/trivial_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ std::set<Expr> GetStoreFromBody(const ir::Expr& body) {
return store_tensor_exprs;
}

<<<<<<< HEAD
bool CheckIterEq(std::vector<ir::Var> up_iter, std::vector<ir::Var> down_iter) {
TODO
}
Expand Down Expand Up @@ -415,6 +414,14 @@ struct FusionNode {
}

bool IsTrivial() { return std::holds_alternative<TrivialOp>(fusible_op); }

ir::Expr GetExpr(){
if (IsTrivial()){
return std::get<TrivialOp>(fusible_op).GetFuncBody();
}else{
return std::get<ReduceOp>(fusible_op).GetFuncBody();
}
}
};

TrivialOp TTFusion(TrivialOp upstream, TrivialOp downstream) {
Expand Down Expand Up @@ -648,7 +655,7 @@ struct FusionGraph {
std::vector<ir::Expr> GetExprResults() {
std::vector<ir::Expr> output_exprs;
for (const auto& node : all_fusion_nodes_) {
output_exprs.emplace_back(node->op_compute_body);
output_exprs.emplace_back(node->GetExpr());
}
return output_exprs;
}
Expand Down

0 comments on commit efe91cc

Please sign in to comment.