Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Add ReductionFactoring rule #57569

Merged
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions paddle/cinn/auto_schedule/analysis/analyze_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,5 +190,40 @@ ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target,
return new_func;
}

std::unordered_set<std::string> GetReduceLoopVarNames(const ir::Expr block) {
  // Collect the names of all loop vars that appear in the binding
  // expressions of reduce iter vars of the given ScheduleBlockRealize.
  const ir::ScheduleBlockRealize* block_realize =
      block.As<ir::ScheduleBlockRealize>();
  CHECK_NOTNULL(block_realize);
  const ir::ScheduleBlock* block_node =
      block_realize->schedule_block.As<ir::ScheduleBlock>();
  CHECK_NOTNULL(block_node);
  // Bind by const reference: no need to copy the vectors.
  const std::vector<ir::Expr>& iter_values = block_realize->iter_values;
  const std::vector<ir::Var>& iter_vars = block_node->iter_vars;
  // Each iter var must have exactly one binding value; a mismatch would
  // make the indexed access below undefined.
  CHECK_EQ(iter_vars.size(), iter_values.size());

  std::unordered_set<std::string> reduce_loop_var;
  for (size_t i = 0; i < iter_vars.size(); ++i) {
    if (iter_vars[i]->is_reduce_axis) {
      // Walk the binding expression and record every loop var it mentions.
      // The visitor always returns false: we only use it for its side
      // effect of inserting names, not to collect nodes.
      ir::ir_utils::CollectIRNodesWithoutTensor(
          iter_values[i], [&](const ir::Expr* x) {
            if (x->as_var()) {
              reduce_loop_var.insert(x->as_var_ref()->name);
            }
            return false;
          });
    }
  }
  return reduce_loop_var;
}

std::string GetBlockName(const ir::Expr block) {
  // Return the name of the ScheduleBlock wrapped by this
  // ScheduleBlockRealize; aborts if the expr is not a realize node.
  const auto* realize = block.As<ir::ScheduleBlockRealize>();
  CHECK_NOTNULL(realize);
  const auto* schedule_block = realize->schedule_block.As<ir::ScheduleBlock>();
  CHECK_NOTNULL(schedule_block);
  return schedule_block->name;
}

} // namespace auto_schedule
} // namespace cinn
10 changes: 10 additions & 0 deletions paddle/cinn/auto_schedule/analysis/analyze_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,15 @@ ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target,
const ir::LoweredFunc& old_func,
ir::Expr& body); // NOLINT

/**
* Get loop var names of reduce axis
*/
std::unordered_set<std::string> GetReduceLoopVarNames(const ir::Expr block);

/**
* Get name of a ScheduleBlock
*/
std::string GetBlockName(const ir::Expr block);

} // namespace auto_schedule
} // namespace cinn
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ TEST_F(TestCooperativeProcess, Matmul) {
{
i0, i1 = axis.bind(((16 * i) + ((2 * i_0) + i_1)), ((16 * j) + ((8 * j_0) + j_1)))
{
temp_matmul_out__reduce_init[((16 * i) + ((2 * i_0) + i_1)), ((16 * j) + ((8 * j_0) + j_1))] = 0.00000000f
temp_matmul_out__reduce_init[i0, i1] = 0.00000000f
}
}
}
Expand Down Expand Up @@ -181,7 +181,7 @@ TEST_F(TestCooperativeProcess, Matmul) {
{
i0_0, i1_0, i2 = axis.bind(((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1)), ((4 * reduce_k_0) + reduce_k_1))
{
temp_matmul_out[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))] = (temp_matmul_out[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))] + (X_reshape_shared_temp_buffer[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((4 * reduce_k_0) + reduce_k_1)] * Y_reshape_shared_temp_buffer[((4 * reduce_k_0) + reduce_k_1), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))]))
temp_matmul_out[i0_0, i1_0] = (temp_matmul_out[i0_0, i1_0] + (X_reshape_shared_temp_buffer[i0_0, i2] * Y_reshape_shared_temp_buffer[i2, i1_0]))
BiynXu marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ gather_srcs(
auto_unroll.cc
multi_level_tiling.cc
skip_rule.cc
auto_bind.cc)
auto_bind.cc
reduction_factoring.cc)

if(WITH_TESTING)
cinn_cc_library(
Expand Down Expand Up @@ -51,3 +52,11 @@ endif()
#cinn_cc_test(test_auto_inline SRCS auto_inline_test.cc DEPS cinncore auto_gen_rule_test_helper)
cinn_cc_test(test_skip_rule SRCS skip_rule_test.cc DEPS cinncore)
cinn_cc_test(test_auto_unroll SRCS auto_unroll_test.cc DEPS cinncore)
cinn_cc_test(
test_reduction_factoring
SRCS
reduction_factoring_test.cc
DEPS
cinncore
auto_gen_rule_test_helper
test_program_builder)
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ TEST_F(TestMultiLevelTiling, Matmul) {
{
i0, i1 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)))
{
temp_matmul_out__reduce_init[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = 0.00000000f
temp_matmul_out__reduce_init[i0, i1] = 0.00000000f
}
}
}
Expand Down Expand Up @@ -308,10 +308,10 @@ TEST_F(TestMultiLevelTiling, Matmul) {
ScheduleBlock(temp_matmul_out_local_temp_buffer)
{
i0_0, i1_0, i2 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)))
read_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)], _X[i(undefined:undefined), reduce_k(undefined:undefined)], _Y[reduce_k(undefined:undefined), j(undefined:undefined)])
write_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)])
read_buffers(_temp_matmul_out[i0_0(0:32), i1_0(0:32)], _X[i0_0(0:32), i2(0:32)], _Y[i2(0:32), i1_0(0:32)])
write_buffers(_temp_matmul_out[i0_0(0:32), i1_0(0:32)])
{
temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = (temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] + (X_reshape_shared_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2))] * Y_reshape_shared_temp_buffer[((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)), ((32 * j_1) + ((32 * j_2) + j_3))]))
temp_matmul_out_local_temp_buffer[i0_0, i1_0] = (temp_matmul_out_local_temp_buffer[i0_0, i1_0] + (X_reshape_shared_temp_buffer[i0_0, i2] * Y_reshape_shared_temp_buffer[i2, i1_0]))
}
}
}
Expand Down Expand Up @@ -453,7 +453,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
i0, i1, i2, i3 = axis.bind(i, j, k, a)
{
pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f)
pad_temp_0[i0, i1, i2, i3] = select(((i3 < (1 + 16)) and ((i3 >= 1) and ((i2 < (1 + 16)) and (i2 >= 1)))), input[i0, i1, (i2 - 1), (i3 - 1)], -3.40282347e+38f)
}
}
}
Expand All @@ -477,7 +477,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1))
{
var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f
var_0__reduce_init[i0_0, i1_0, i2_0, i3_0] = -3.40282347e+38f
}
}
}
Expand Down Expand Up @@ -511,10 +511,10 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
ScheduleBlock(var_0_local_temp_buffer)
{
i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0)
read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)])
write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)])
read_buffers(_var_0[i0_1(0:2), i1_1(0:8), i2_1(0:8), i3_1(0:8)], _pad_temp_0[i0_1(0:2), i1_1(0:8)])
write_buffers(_var_0[i0_1(0:2), i1_1(0:8), i2_1(0:8), i3_1(0:8)])
{
var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))])
var_0_local_temp_buffer[i0_1, i1_1, i2_1, i3_1] = cinn_max(var_0_local_temp_buffer[i0_1, i1_1, i2_1, i3_1], pad_temp_0_shared_temp_buffer[i0_1, i1_1, ((2 * i2_1) + i4), ((2 * i3_1) + i5)])
}
}
}
Expand All @@ -533,7 +533,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
ScheduleBlock(var_0)
{
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((i_0_j_0_k_0_a_0_fused % 4) + (4 * ((i_j_k_a_fused / 2) % 2))) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0)
{
var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h"

#include <glog/logging.h>

#include "paddle/cinn/auto_schedule/analysis/analyze_ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"

namespace cinn {
namespace auto_schedule {

bool ReductionFactoring::CanApply(const std::string& block_name,
ir::IRSchedule* ir_schedule) const {
ir::Expr block_expr = ir_schedule->GetBlock(block_name);
ir::ScheduleBlockRealize* block_realize =
block_expr.As<ir::ScheduleBlockRealize>();
CHECK_NOTNULL(block_realize);
ir::ScheduleBlock* sch_block =
block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK_NOTNULL(sch_block);
AnalyzeScheduleBlockReadWriteBuffer(sch_block);

// 1. The block must have write buffer
if (sch_block->write_buffers.empty()) {
return false;
}

// 2. The block must have at least one reduce axis
const std::vector<ir::Var>& iter_vars = sch_block->iter_vars;
bool find_reduce_axis = false;
for (int i = 0; i < iter_vars.size(); ++i) {
if (iter_vars[i]->is_reduce_axis) {
find_reduce_axis = true;
BiynXu marked this conversation as resolved.
Show resolved Hide resolved
break;
}
}
if (!find_reduce_axis) {
return false;
}

// 3. Each loop's body only contains one sub loop or block, except reduce_init
// block
std::vector<ir::Expr> loops = ir_schedule->GetLoops(block_name);
for (const ir::Expr& loop : loops) {
const ir::Expr& body = loop.As<ir::For>()->body;
if (body.As<ir::Block>()) {
if (body.As<ir::Block>()->stmts.size() == 1) {
if (body.As<ir::Block>()->stmts[0].As<ir::For>() == nullptr &&
body.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>() ==
nullptr) {
return false;
}
} else if (body.As<ir::Block>()->stmts.size() == 2) {
if (body.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>() ==
nullptr ||
!ir::IsReduceInitTensorName(
GetBlockName(body.As<ir::Block>()->stmts[0]))) {
return false;
}
if (body.As<ir::Block>()->stmts[1].As<ir::For>() == nullptr &&
body.As<ir::Block>()->stmts[1].As<ir::ScheduleBlockRealize>() ==
nullptr) {
return false;
}
} else {
return false;
}
} else if (body.As<ir::For>() || body.As<ir::ScheduleBlockRealize>()) {
continue;
} else {
return false;
}
}

return true;
}

RuleApplyType ReductionFactoring::AnalyseApplyType(
    SearchState state, const std::string& block_name) const {
  // The rule is applicable exactly when all CanApply preconditions hold.
  if (this->CanApply(block_name, &(state->ir_schedule))) {
    return RuleApplyType::kApply;
  }
  return RuleApplyType::kCannotApply;
}

// Apply the rule to `block_name`, producing one new candidate state.
// The input state is left untouched: the schedule is mutated on a copy.
std::vector<SearchState> ReductionFactoring::ApplyOnBlock(
    SearchState state, const std::string& block_name) {
  SearchState new_state = state.Copy();
  Apply(block_name, &(new_state->ir_schedule));
  return {new_state};
}

// Reorder the loops of `block_name` so all spatial loops precede all
// reduction loops, then apply Rfactor on the first reduction loop.
void ReductionFactoring::Apply(const std::string& block_name,
                               ir::IRSchedule* ir_schedule) {
  ir::Expr block = ir_schedule->GetBlock(block_name);
  std::vector<ir::Expr> all_loops = ir_schedule->GetLoops(block_name);

  std::vector<ir::Expr> new_loop_order;
  size_t num_spatial_loops = 0;
  size_t num_reduction_loops = 0;
  // 1. Add all spatial loops
  std::unordered_set<std::string> reduce_loop_var_names =
      GetReduceLoopVarNames(block);
  for (const ir::Expr& expr : all_loops) {
    if (reduce_loop_var_names.count(expr.As<ir::For>()->loop_var->name) == 0) {
      new_loop_order.push_back(expr);
      ++num_spatial_loops;
    }
  }
  // 2. Add all reduction loops
  for (const ir::Expr& expr : all_loops) {
    if (reduce_loop_var_names.count(expr.As<ir::For>()->loop_var->name) > 0) {
      new_loop_order.push_back(expr);
      ++num_reduction_loops;
    }
  }
  // Nothing to factor if the block has no reduction loops.
  if (num_reduction_loops == 0) {
    return;
  }
  // 3. Reorder if new_loop_order differs from the original order.
  // Comparing loop var names is enough: a single mismatch means the
  // partition above changed the relative order.
  CHECK_EQ(all_loops.size(), new_loop_order.size());
  for (size_t i = 0; i < all_loops.size(); ++i) {
    if (all_loops[i].As<ir::For>()->loop_var->name !=
        new_loop_order[i].As<ir::For>()->loop_var->name) {
      ir_schedule->Reorder(new_loop_order);
      break;
    }
  }

  // TODO(BiynXu): After implementing the factorize_reduction schedule
  // primitive, restore the following annotations. The factorize_reduction
  // schedule primitive needs to support complex subscripts to support pre
  // schedule transformations.

  // // 4. Fuse all reduction loops
  // ir::Expr fused_reduce_loop;
  // if (num_reduction_loops > 1) {
  //   std::vector<int> reduction_loop_indices;
  //   for (int i = num_spatial_loops - 1; i < all_loops.size(); ++i) {
  //     reduction_loop_indices.push_back(i);
  //   }
  //   CHECK_EQ(reduction_loop_indices.size(), num_reduction_loops);
  //   fused_reduce_loop = ir_schedule->Fuse(block_name,
  //   reduction_loop_indices);
  // } else {
  //   all_loops = ir_schedule->GetLoops(block_name);
  //   fused_reduce_loop = all_loops.back();
  // }
  // // 5. Split the reduction loop into 2 part
  // int factor = 1;
  // int extent = ir::GetLoopExtent(fused_reduce_loop);
  // for (int i = ceil(sqrt(extent)); i >= 1; --i) {
  //   if (extent % i == 0) {
  //     factor = i;
  //     break;
  //   }
  // }
  // std::vector<cinn::ir::Expr> splited_reduction_loops =
  //     ir_schedule->Split(fused_reduce_loop, {-1, factor});
  // // Apply FactorizeReduction
  // LOG(INFO) << "before FactorizeReduction: " <<
  // ir_schedule->GetModule().GetExprs()[0];
  // ir_schedule->FactorizeReduction(splited_reduction_loops[0],
  // num_spatial_loops);

  // Apply rfactor on the first reduction loop (loops are refetched because
  // Reorder above may have invalidated the previously captured exprs).
  all_loops = ir_schedule->GetLoops(block_name);
  ir_schedule->Rfactor(all_loops[num_spatial_loops], num_spatial_loops);
}

} // namespace auto_schedule
} // namespace cinn
Loading