Skip to content

Commit

Permalink
[CINN] Add ReductionFactoring rule (PaddlePaddle#57569)
Browse files Browse the repository at this point in the history
Add ReductionFactoring rule
  • Loading branch information
BiynXu authored and Frida-a committed Oct 14, 2023
1 parent fcdbc4b commit bb548a4
Show file tree
Hide file tree
Showing 16 changed files with 499 additions and 29 deletions.
35 changes: 35 additions & 0 deletions paddle/cinn/auto_schedule/analysis/analyze_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,5 +190,40 @@ ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target,
return new_func;
}

/**
 * Collect the names of all loop vars referenced by the binding expressions
 * of reduce iter vars in a ScheduleBlockRealize.
 *
 * @param block An Expr that must be a ScheduleBlockRealize (checked).
 * @return Names of loop vars that feed any reduce axis of the block.
 */
std::unordered_set<std::string> GetReduceLoopVarNames(const ir::Expr block) {
  const ir::ScheduleBlockRealize* block_realize =
      block.As<ir::ScheduleBlockRealize>();
  CHECK_NOTNULL(block_realize);
  const ir::ScheduleBlock* block_node =
      block_realize->schedule_block.As<ir::ScheduleBlock>();
  CHECK_NOTNULL(block_node);
  // Bind by const reference; the original copied both vectors for no reason.
  const std::vector<ir::Expr>& iter_values = block_realize->iter_values;
  const std::vector<ir::Var>& iter_vars = block_node->iter_vars;

  std::unordered_set<std::string> reduce_loop_var;
  for (size_t i = 0; i < iter_vars.size(); ++i) {
    if (!iter_vars[i]->is_reduce_axis) {
      continue;
    }
    // Walk the binding expression of this reduce axis and record every loop
    // var it references. The visitor returns false unconditionally because we
    // only need the traversal's side effect, not a set of collected nodes.
    ir::ir_utils::CollectIRNodesWithoutTensor(
        iter_values[i], [&](const ir::Expr* x) {
          if (x->as_var()) {
            reduce_loop_var.insert(x->as_var_ref()->name);
          }
          return false;
        });
  }
  return reduce_loop_var;
}

/**
 * Return the name of a ScheduleBlock.
 *
 * @param block An Expr that must be a ScheduleBlockRealize (checked).
 * @return The name stored on the inner ScheduleBlock node.
 */
std::string GetBlockName(const ir::Expr block) {
  // A block Expr is a ScheduleBlockRealize wrapping a ScheduleBlock; the
  // name lives on the inner ScheduleBlock node.
  const auto* realize = block.As<ir::ScheduleBlockRealize>();
  CHECK_NOTNULL(realize);
  const auto* sch_block = realize->schedule_block.As<ir::ScheduleBlock>();
  CHECK_NOTNULL(sch_block);
  return sch_block->name;
}

} // namespace auto_schedule
} // namespace cinn
10 changes: 10 additions & 0 deletions paddle/cinn/auto_schedule/analysis/analyze_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,15 @@ ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target,
const ir::LoweredFunc& old_func,
ir::Expr& body); // NOLINT

/**
* Get loop var names of reduce axis
*/
std::unordered_set<std::string> GetReduceLoopVarNames(const ir::Expr block);

/**
* Get name of a ScheduleBlock
*/
std::string GetBlockName(const ir::Expr block);

} // namespace auto_schedule
} // namespace cinn
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ TEST_F(TestCooperativeProcess, Matmul) {
{
i0, i1 = axis.bind(((16 * i) + ((2 * i_0) + i_1)), ((16 * j) + ((8 * j_0) + j_1)))
{
temp_matmul_out__reduce_init[((16 * i) + ((2 * i_0) + i_1)), ((16 * j) + ((8 * j_0) + j_1))] = 0.00000000f
temp_matmul_out__reduce_init[i0, i1] = 0.00000000f
}
}
}
Expand Down Expand Up @@ -181,7 +181,7 @@ TEST_F(TestCooperativeProcess, Matmul) {
{
i0_0, i1_0, i2 = axis.bind(((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1)), ((4 * reduce_k_0) + reduce_k_1))
{
temp_matmul_out[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))] = (temp_matmul_out[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))] + (X_reshape_shared_temp_buffer[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((4 * reduce_k_0) + reduce_k_1)] * Y_reshape_shared_temp_buffer[((4 * reduce_k_0) + reduce_k_1), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))]))
temp_matmul_out[i0_0, i1_0] = (temp_matmul_out[i0_0, i1_0] + (X_reshape_shared_temp_buffer[i0_0, i2] * Y_reshape_shared_temp_buffer[i2, i1_0]))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ gather_srcs(
auto_unroll.cc
multi_level_tiling.cc
skip_rule.cc
auto_bind.cc)
auto_bind.cc
reduction_factoring.cc)

if(WITH_TESTING)
cinn_cc_library(
Expand Down Expand Up @@ -51,3 +52,11 @@ endif()
#cinn_cc_test(test_auto_inline SRCS auto_inline_test.cc DEPS cinncore auto_gen_rule_test_helper)
cinn_cc_test(test_skip_rule SRCS skip_rule_test.cc DEPS cinncore)
cinn_cc_test(test_auto_unroll SRCS auto_unroll_test.cc DEPS cinncore)
cinn_cc_test(
test_reduction_factoring
SRCS
reduction_factoring_test.cc
DEPS
cinncore
auto_gen_rule_test_helper
test_program_builder)
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ TEST_F(TestMultiLevelTiling, Matmul) {
{
i0, i1 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)))
{
temp_matmul_out__reduce_init[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = 0.00000000f
temp_matmul_out__reduce_init[i0, i1] = 0.00000000f
}
}
}
Expand Down Expand Up @@ -308,10 +308,10 @@ TEST_F(TestMultiLevelTiling, Matmul) {
ScheduleBlock(temp_matmul_out_local_temp_buffer)
{
i0_0, i1_0, i2 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)))
read_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)], _X[i(undefined:undefined), reduce_k(0:32)], _Y[reduce_k(0:32), j(undefined:undefined)])
write_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)])
read_buffers(_temp_matmul_out[i0_0(0:32), i1_0(0:32)], _X[i0_0(0:32), i2(0:32)], _Y[i2(0:32), i1_0(0:32)])
write_buffers(_temp_matmul_out[i0_0(0:32), i1_0(0:32)])
{
temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = (temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] + (X_reshape_shared_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2))] * Y_reshape_shared_temp_buffer[((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)), ((32 * j_1) + ((32 * j_2) + j_3))]))
temp_matmul_out_local_temp_buffer[i0_0, i1_0] = (temp_matmul_out_local_temp_buffer[i0_0, i1_0] + (X_reshape_shared_temp_buffer[i0_0, i2] * Y_reshape_shared_temp_buffer[i2, i1_0]))
}
}
}
Expand Down Expand Up @@ -453,7 +453,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
i0, i1, i2, i3 = axis.bind(i, j, k, a)
{
pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f)
pad_temp_0[i0, i1, i2, i3] = select(((i3 < (1 + 16)) and ((i3 >= 1) and ((i2 < (1 + 16)) and (i2 >= 1)))), input[i0, i1, (i2 - 1), (i3 - 1)], -3.40282347e+38f)
}
}
}
Expand All @@ -477,7 +477,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1))
{
var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f
var_0__reduce_init[i0_0, i1_0, i2_0, i3_0] = -3.40282347e+38f
}
}
}
Expand Down Expand Up @@ -511,10 +511,10 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
ScheduleBlock(var_0_local_temp_buffer)
{
i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0)
read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)])
write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)])
read_buffers(_var_0[i0_1(0:2), i1_1(0:8), i2_1(0:8), i3_1(0:8)], _pad_temp_0[i0_1(0:2), i1_1(0:8)])
write_buffers(_var_0[i0_1(0:2), i1_1(0:8), i2_1(0:8), i3_1(0:8)])
{
var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))])
var_0_local_temp_buffer[i0_1, i1_1, i2_1, i3_1] = cinn_max(var_0_local_temp_buffer[i0_1, i1_1, i2_1, i3_1], pad_temp_0_shared_temp_buffer[i0_1, i1_1, ((2 * i2_1) + i4), ((2 * i3_1) + i5)])
}
}
}
Expand All @@ -533,7 +533,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
ScheduleBlock(var_0)
{
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((i_0_j_0_k_0_a_0_fused % 4) + (4 * ((i_j_k_a_fused / 2) % 2))) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0)
{
var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h"

#include <glog/logging.h>

#include "paddle/cinn/auto_schedule/analysis/analyze_ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"

namespace cinn {
namespace auto_schedule {

bool ReductionFactoring::CanApply(const std::string& block_name,
ir::IRSchedule* ir_schedule) const {
ir::Expr block_expr = ir_schedule->GetBlock(block_name);
ir::ScheduleBlockRealize* block_realize =
block_expr.As<ir::ScheduleBlockRealize>();
CHECK_NOTNULL(block_realize);
ir::ScheduleBlock* sch_block =
block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK_NOTNULL(sch_block);
AnalyzeScheduleBlockReadWriteBuffer(sch_block);

// 1. The block must have write buffer
if (sch_block->write_buffers.empty()) {
return false;
}

// 2. The block must have at least one reduce axis
const std::vector<ir::Var>& iter_vars = sch_block->iter_vars;
bool find_reduce_axis = false;
for (int i = 0; i < iter_vars.size(); ++i) {
if (iter_vars[i]->is_reduce_axis) {
find_reduce_axis = true;
break;
}
}
if (!find_reduce_axis) {
return false;
}

// 3. Each loop's body only contains one sub loop or block, except reduce_init
// block
std::vector<ir::Expr> loops = ir_schedule->GetLoops(block_name);
for (const ir::Expr& loop : loops) {
const ir::Expr& body = loop.As<ir::For>()->body;
if (body.As<ir::Block>()) {
if (body.As<ir::Block>()->stmts.size() == 1) {
if (body.As<ir::Block>()->stmts[0].As<ir::For>() == nullptr &&
body.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>() ==
nullptr) {
return false;
}
} else if (body.As<ir::Block>()->stmts.size() == 2) {
if (body.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>() ==
nullptr ||
!ir::IsReduceInitTensorName(
GetBlockName(body.As<ir::Block>()->stmts[0]))) {
return false;
}
if (body.As<ir::Block>()->stmts[1].As<ir::For>() == nullptr &&
body.As<ir::Block>()->stmts[1].As<ir::ScheduleBlockRealize>() ==
nullptr) {
return false;
}
} else {
return false;
}
} else if (body.As<ir::For>() || body.As<ir::ScheduleBlockRealize>()) {
continue;
} else {
return false;
}
}

return true;
}

RuleApplyType ReductionFactoring::AnalyseApplyType(
    SearchState state, const std::string& block_name) const {
  // Delegate to CanApply: this rule is either fully applicable to the block
  // or not applicable at all.
  if (this->CanApply(block_name, &(state->ir_schedule))) {
    return RuleApplyType::kApply;
  }
  return RuleApplyType::kCannotApply;
}

std::vector<SearchState> ReductionFactoring::ApplyOnBlock(
    SearchState state, const std::string& block_name) {
  // Work on a copy so the incoming state is left untouched; the rule
  // produces exactly one candidate state.
  SearchState next_state = state.Copy();
  this->Apply(block_name, &next_state->ir_schedule);
  return {next_state};
}

/**
 * Apply reduction factoring to the named block.
 *
 * Reorders the block's loops so that all spatial loops precede all reduction
 * loops (preserving relative order within each group), then applies Rfactor
 * on the outermost reduction loop.
 *
 * @param block_name  Name of the schedule block to transform.
 * @param ir_schedule Schedule to mutate in place.
 */
void ReductionFactoring::Apply(const std::string& block_name,
                               ir::IRSchedule* ir_schedule) {
  ir::Expr block = ir_schedule->GetBlock(block_name);
  std::vector<ir::Expr> all_loops = ir_schedule->GetLoops(block_name);

  std::vector<ir::Expr> new_loop_order;
  size_t num_spatial_loops = 0;
  size_t num_reduction_loops = 0;
  // 1. Add all spatial loops (loop vars not bound to any reduce axis)
  std::unordered_set<std::string> reduce_loop_var_names =
      GetReduceLoopVarNames(block);
  for (const ir::Expr& expr : all_loops) {
    if (reduce_loop_var_names.count(expr.As<ir::For>()->loop_var->name) == 0) {
      new_loop_order.push_back(expr);
      ++num_spatial_loops;
    }
  }
  // 2. Add all reduction loops
  for (const ir::Expr& expr : all_loops) {
    if (reduce_loop_var_names.count(expr.As<ir::For>()->loop_var->name) > 0) {
      new_loop_order.push_back(expr);
      ++num_reduction_loops;
    }
  }
  // Nothing to factor when the block has no reduction loop.
  if (num_reduction_loops == 0) {
    return;
  }
  // 3. Reorder if new_loop_order differs from the original order, so an
  // already-partitioned loop nest skips the schedule transformation.
  CHECK_EQ(all_loops.size(), new_loop_order.size());
  // size_t index: comparing a signed int against size() mixed
  // signed/unsigned in the original.
  for (size_t i = 0; i < all_loops.size(); ++i) {
    if (all_loops[i].As<ir::For>()->loop_var->name !=
        new_loop_order[i].As<ir::For>()->loop_var->name) {
      ir_schedule->Reorder(new_loop_order);
      break;
    }
  }

  // TODO(BiynXu): After implementing the factorize_reduction schedule
  // primitive, restore the following annotations. The factorize_reduction
  // schedule primitive needs to support complex subscripts to support pre
  // schedule transformations.

  // // 4. Fuse all reduction loops
  // ir::Expr fused_reduce_loop;
  // if (num_reduction_loops > 1) {
  //   std::vector<int> reduction_loop_indices;
  //   for (int i = num_spatial_loops - 1; i < all_loops.size(); ++i) {
  //     reduction_loop_indices.push_back(i);
  //   }
  //   CHECK_EQ(reduction_loop_indices.size(), num_reduction_loops);
  //   fused_reduce_loop = ir_schedule->Fuse(block_name,
  //   reduction_loop_indices);
  // } else {
  //   all_loops = ir_schedule->GetLoops(block_name);
  //   fused_reduce_loop = all_loops.back();
  // }
  // // 5. Split the reduction loop into 2 part
  // int factor = 1;
  // int extent = ir::GetLoopExtent(fused_reduce_loop);
  // for (int i = ceil(sqrt(extent)); i >= 1; --i) {
  //   if (extent % i == 0) {
  //     factor = i;
  //     break;
  //   }
  // }
  // std::vector<cinn::ir::Expr> splited_reduction_loops =
  //     ir_schedule->Split(fused_reduce_loop, {-1, factor});
  // // Apply FactorizeReduction
  // LOG(INFO) << "before FactorizeReduction: " <<
  // ir_schedule->GetModule().GetExprs()[0];
  // ir_schedule->FactorizeReduction(splited_reduction_loops[0],
  // num_spatial_loops);

  // Apply rfactor on the outermost reduction loop. Loops are re-fetched
  // because Reorder may have replaced the loop Exprs.
  all_loops = ir_schedule->GetLoops(block_name);
  ir_schedule->Rfactor(all_loops[num_spatial_loops], num_spatial_loops);
}

} // namespace auto_schedule
} // namespace cinn
Loading

0 comments on commit bb548a4

Please sign in to comment.