Skip to content

Commit

Permalink
Cinn trivalop fuse (PaddlePaddle#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 authored Mar 12, 2024
1 parent 1adce1e commit 185f288
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions paddle/cinn/hlir/framework/pir/trivial_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,11 @@ struct FusionNode {

explicit FusionNode(FusibleOp fusible_op) : fusible_op(fusible_op) {}

static std::string GetTensorCounter() {
static int i = 0;
return std::to_string(i++);
}

void replace_topo_structure_of_fused_nodes(FusionNode* fused_up_node,
FusionNode* fused_down_node) {
upstream.insert(fused_up_node->upstream.begin(),
Expand Down Expand Up @@ -535,17 +540,30 @@ std::vector<ReduceOp> TransformReduceLoopRange(ReduceOp upstream,
const auto& load_upstream_expr =
downstream.GetEachTensorLoadExpr(upstream.GetOutputTensor());
std::vector<ReduceOp> results;
ir::Tensor downstream_output_tensor = downstream.GetOutputTensor();
const auto create_new_tensor = [&](const ir::Tensor& downstream_load_tensor) {
return ir::Tensor(
downstream_load_tensor->name + FusionNode::GetTensorCounter(),
downstream_load_tensor->type(),
downstream_output_tensor.self()->sym_shape,
downstream_load_tensor.self()->sym_domain,
downstream_load_tensor.self()->operation,
downstream_output_tensor.self()->reduce_axis);
};

for (const auto& load_tensor : load_upstream_expr) {
const auto& new_tensor = create_new_tensor(
*(load_tensor.As<ir::Load>()->tensor.As<ir::Tensor>()));
ir::Expr new_reduce = CreateReduceExpr(
downstream,
ComposeUtils::CopyedReplaceExpr(upstream.GetFuncBody(),
ComposeUtils::CopyedReplaceExpr(upstream.GetComputeExpr(),
upstream.GetOutputIters(),
load_tensor.As<ir::Load>()->indices),
upstream.GetInitExpr(),
new_tensor);
ComposeUtils::MappingTargetExprToDestExprMutator(
load_tensor.As<ir::Load>()->tensor,
new_tensor)(downstream.GetFuncBody());
Expr(new_tensor))(&downstream.GetFuncBody());
results.emplace_back(new_reduce);
}

Expand Down

0 comments on commit 185f288

Please sign in to comment.