diff --git a/paddle/cinn/backends/ir_schedule_test.cc b/paddle/cinn/backends/ir_schedule_test.cc index 5ea30c6951d24..2923c8dc9fe7a 100644 --- a/paddle/cinn/backends/ir_schedule_test.cc +++ b/paddle/cinn/backends/ir_schedule_test.cc @@ -2310,6 +2310,270 @@ void test_rfactor(void* _args, int32_t num_args) ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); } +TEST(IrSchedule, factorize_reduction) { + Context::Global().ResetNameId(); + Expr M(3); + Expr N(4); + Expr K(5); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N, K}); + Var j(4, "j0"); + Var k(5, "k0"); + auto B = Compute( + {M}, + [&](Var i) { + return lang::ReduceSum(A(i, j, k), {j, k}); + }, + "B"); + + auto stages = CreateStages({A, B}); + auto func = cinn::lang::LowerVec("test_factorize_reduction", + stages, + {A, B}, + {}, + {}, + nullptr, + target, + true); + CHECK(!func.empty()); + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + auto loops = ir_sch.GetLoops("B"); + CHECK_EQ(loops.size(), 3U); + auto new_rf_tensor = ir_sch.FactorizeReduction(loops[1], 0); + auto* new_rf_tensor_ref = new_rf_tensor.As(); + CHECK(new_rf_tensor_ref); + CHECK(new_rf_tensor_ref->buffer.defined()); + func[0]->temp_bufs.push_back(new_rf_tensor_ref->buffer); + func[0]->PrepareBufferCastExprs(); + std::string origin = utils::GetStreamCnt(func[0]); + LOG(INFO) << origin; + EXPECT_EQ(origin, utils::Trim(R"ROC( +function test_factorize_reduction (_A, _B) +{ + ScheduleBlock(root) + { + { + serial for (i, 0, 3) + { + serial for (j0, 0, 4) + { + ScheduleBlock(B_rf__reduce_init) + { + vj0, i0_0 = axis.bind(j0, i) + B_rf__reduce_init[vj0, i0_0] = 0.00000000f + } + serial for (k0, 0, 5) + { + ScheduleBlock(B_rf) + { + vj0, i0_0, i2 = axis.bind(j0, i, k0) + B_rf[vj0, i0_0] = (B_rf[vj0, i0_0] + A[i0_0, vj0, i2]) + } + } + } + } + serial for (i, 0, 3) + { + ScheduleBlock(B__reduce_init) + { + i0_0 = axis.bind(i) + 
B__reduce_init[i0_0] = 0.00000000f + } + serial for (j0, 0, 4) + { + ScheduleBlock(B) + { + vj0, i0_0 = axis.bind(j0, i) + B[i0_0] = (B[i0_0] + B_rf[vj0, i0_0]) + } + } + } + } + } +} +)ROC")); +} + +TEST(IrSchedule, factorize_reduction1) { + Context::Global().ResetNameId(); + Expr M(3); + Expr N(4); + Expr K(5); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N, K}); + Var j(4, "j0"); + Var k(5, "k0"); + auto B = Compute( + {M}, + [&](Var i) { + return lang::ReduceSum(A(i, j, k), {j, k}); + }, + "B"); + + auto stages = CreateStages({A, B}); + auto func = cinn::lang::LowerVec("test_factorize_reduction", + stages, + {A, B}, + {}, + {}, + nullptr, + target, + true); + CHECK(!func.empty()); + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + auto loops = ir_sch.GetLoops("B"); + CHECK_EQ(loops.size(), 3U); + auto new_rf_tensor = ir_sch.FactorizeReduction(loops[1], 1); + auto* new_rf_tensor_ref = new_rf_tensor.As(); + CHECK(new_rf_tensor_ref); + CHECK(new_rf_tensor_ref->buffer.defined()); + func[0]->temp_bufs.push_back(new_rf_tensor_ref->buffer); + func[0]->PrepareBufferCastExprs(); + std::string origin = utils::GetStreamCnt(func[0]); + LOG(INFO) << origin; + EXPECT_EQ(origin, utils::Trim(R"ROC( +function test_factorize_reduction (_A, _B) +{ + ScheduleBlock(root) + { + { + serial for (i, 0, 3) + { + serial for (j0, 0, 4) + { + ScheduleBlock(B_rf__reduce_init) + { + vj0, i0_0 = axis.bind(j0, i) + B_rf__reduce_init[i0_0, vj0] = 0.00000000f + } + serial for (k0, 0, 5) + { + ScheduleBlock(B_rf) + { + vj0, i0_0, i2 = axis.bind(j0, i, k0) + B_rf[i0_0, vj0] = (B_rf[i0_0, vj0] + A[i0_0, vj0, i2]) + } + } + } + } + serial for (i, 0, 3) + { + ScheduleBlock(B__reduce_init) + { + i0_0 = axis.bind(i) + B__reduce_init[i0_0] = 0.00000000f + } + serial for (j0, 0, 4) + { + ScheduleBlock(B) + { + vj0, i0_0 = axis.bind(j0, i) + B[i0_0] = (B[i0_0] + B_rf[i0_0, vj0]) + } + } + 
} + } + } +} +)ROC")); +} + +TEST(IrSchedule, factorize_reduction2) { + Context::Global().ResetNameId(); + Expr M(3); + Expr N(4); + Expr K(5); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N * K}); + Var j(4 * 5, "j0"); + auto B = Compute( + {M}, [&](Var i) { return lang::ReduceSum(A(i, j), {j}); }, "B"); + + auto stages = CreateStages({A, B}); + auto func = cinn::lang::LowerVec("test_factorize_reduction", + stages, + {A, B}, + {}, + {}, + nullptr, + target, + true); + CHECK(!func.empty()); + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + auto loops = ir_sch.GetLoops("B"); + CHECK_EQ(loops.size(), 2U); + auto splited_loops = ir_sch.Split(loops[1], {4, 5}); + CHECK_EQ(splited_loops.size(), 2U); + auto new_rf_tensor = ir_sch.FactorizeReduction(splited_loops[0], 1); + auto* new_rf_tensor_ref = new_rf_tensor.As(); + CHECK(new_rf_tensor_ref); + CHECK(new_rf_tensor_ref->buffer.defined()); + func[0]->temp_bufs.push_back(new_rf_tensor_ref->buffer); + func[0]->PrepareBufferCastExprs(); + std::string origin = utils::GetStreamCnt(func[0]); + LOG(INFO) << origin; + EXPECT_EQ(origin, utils::Trim(R"ROC( +function test_factorize_reduction (_A, _B) +{ + ScheduleBlock(root) + { + { + serial for (i, 0, 3) + { + serial for (j0, 0, 4) + { + ScheduleBlock(B_rf__reduce_init) + { + vj0, i0_0 = axis.bind(j0, i) + B_rf__reduce_init[i0_0, vj0] = 0.00000000f + } + serial for (j0_0, 0, 5) + { + ScheduleBlock(B_rf) + { + vj0, i0_0, vj0_0 = axis.bind(j0, i, j0_0) + B_rf[i0_0, vj0] = (B_rf[i0_0, vj0] + A[i0_0, ((5 * vj0) + vj0_0)]) + } + } + } + } + serial for (i, 0, 3) + { + ScheduleBlock(B__reduce_init) + { + i0_0 = axis.bind(i) + B__reduce_init[i0_0] = 0.00000000f + } + serial for (j0, 0, 4) + { + ScheduleBlock(B) + { + vj0, i0_0 = axis.bind(j0, i) + B[i0_0] = (B[i0_0] + B_rf[i0_0, vj0]) + } + } + } + } + } +} +)ROC")); +} + TEST(IrSchedule, compute_inline1) { 
Context::Global().ResetNameId(); Expr M(32); diff --git a/paddle/cinn/ir/schedule/factorize_reduction.h b/paddle/cinn/ir/schedule/factorize_reduction.h new file mode 100644 index 0000000000000..0973d123fd40c --- /dev/null +++ b/paddle/cinn/ir/schedule/factorize_reduction.h @@ -0,0 +1,408 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Used in FactorizeReduction + +#pragma once +#include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/ir_base.h" +#include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" +#include "paddle/cinn/ir/tensor.h" +#include "paddle/cinn/ir/utils/ir_copy.h" +#include "paddle/cinn/lang/compute.h" +#include "paddle/cinn/optim/replace_var_with_expr.h" +#include "paddle/cinn/utils/error.h" + +namespace cinn { +namespace ir { + +// Create the new Reduction-Factorized tensor, +// only used for FactorizeReduction schedule primitive. 
+Tensor CreateRFTensor(const Tensor& original_tensor, + const Expr& rf_loop, + int rf_axis) { + std::string name = original_tensor->name + "_rf"; + std::vector new_shape = original_tensor->shape; + new_shape.insert(new_shape.begin() + rf_axis, rf_loop.As()->extent); + Tensor rf_tensor = _Tensor_::Make(name, + original_tensor->type(), + new_shape, + new_shape, + original_tensor->operation, + original_tensor->reduce_axis); + rf_tensor->WithBuffer("global", name, original_tensor->type()); + return rf_tensor; +} + +// Base class to create a new reduce block, +// only used for FactorizeReduction schedule primitive. +class ReduceBlockCreater { + public: + ReduceBlockCreater(const Expr& original_block, + const std::vector& original_loops, + const Expr& rf_loop, + const Expr& original_update_stmt, + const ir::Tensor& rf_tensor, + bool is_rf_block) + : original_block_(original_block), + original_loops_(original_loops), + rf_loop_(rf_loop), + original_update_stmt_(original_update_stmt), + rf_tensor_(rf_tensor), + is_rf_block_(is_rf_block) { + const ScheduleBlockRealize* block_real = + original_block_.As(); + CHECK_NOTNULL(block_real); + num_block_iters_ = block_real->iter_values.size(); + } + + void CreateBlock() { + CreateRFIter(); + for (int i = 0; i < num_block_iters_; ++i) { + CreateNormalIter(i); + } + CreateUpdateStmt(); + + std::string new_update_block_name = + original_block_.As() + ->schedule_block.As() + ->name; + if (is_rf_block_) { + new_update_block_name += "_rf"; + } + std::string new_init_block_name = + ir::GenReduceInitTensorNameOf(new_update_block_name); + VLOG(5) << "new_init_block_name = " << new_init_block_name; + + Expr init_value = rf_tensor_->GetReduceInitVal(); + const std::vector& domain = rf_tensor_->domain_without_reduce_axis(); + ir::Tensor init_tensor = lang::Compute( + domain, + [=](const std::vector& axis) { return init_value; }, + new_init_block_name); + init_tensor->Bind(rf_tensor_->buffer); + Expr init_stmt = ir::Store::Make( + init_tensor, 
init_value, new_update_stmt_.As()->indices); + new_init_sch_block_ = ScheduleBlock::Make( + new_init_iter_vars_, {}, {}, new_init_block_name, init_stmt); + new_init_block_realize_ = + ScheduleBlockRealize::Make(new_init_iter_values_, new_init_sch_block_); + + new_update_sch_block_ = ScheduleBlock::Make( + new_iter_vars_, {}, {}, new_update_block_name, new_update_stmt_); + new_update_block_realize_ = + ScheduleBlockRealize::Make(new_iter_values_, new_update_sch_block_); + VLOG(4) << "new_update_block_realize:\n" << new_update_block_realize_; + } + + Expr CreateLoops() { + int num_loops = original_loops_.size(); + std::vector new_loops(num_loops); + Expr body = new_update_block_realize_; + bool has_add_init_block = false; + for (int i = num_loops - 1; i >= 0; --i) { + bool is_spatial_loop = + new_spatial_loop_var_names_.count( + original_loops_[i].As()->loop_var->name) > 0; + bool is_rf_loop = rf_loop_.As()->loop_var->name == + original_loops_[i].As()->loop_var->name; + // Skip non rf reduction loops of write back block. + if (!is_rf_block_ && !is_spatial_loop && !is_rf_loop) { + continue; + } + // Add reduce init block. 
+ if (!has_add_init_block && is_spatial_loop) { + body = Block::Make({new_init_block_realize_, body}); + has_add_init_block = true; + } + // Add loops + Var loop_var = ir_utils::IRCopy(original_loops_[i].As()->loop_var); + Expr min = ir_utils::IRCopy(original_loops_[i].As()->min); + Expr extent = ir_utils::IRCopy(original_loops_[i].As()->extent); + body = For::Make(loop_var, + min, + extent, + original_loops_[i].As()->for_type(), + original_loops_[i].As()->device_api, + body, + original_loops_[i].As()->vectorize_info(), + original_loops_[i].As()->bind_info()); + VLOG(5) << "new body:\n" << body; + } + VLOG(4) << "new loop nest:\n" << body; + return body; + } + + private: + virtual void CreateRFIter() = 0; + virtual void CreateNormalIter(int idx) = 0; + virtual void CreateUpdateStmt() = 0; + + public: + Var rf_var_; + std::vector rf_tensor_access_indices_; + + protected: + const Expr& original_block_; + const std::vector& original_loops_; + const Expr& rf_loop_; + const Expr& original_update_stmt_; + const ir::Tensor& rf_tensor_; + std::map original_indice2new_expr_; + int num_block_iters_; + bool is_rf_block_; + + std::vector new_iter_vars_; + std::vector new_iter_values_; + std::vector new_init_iter_vars_; + std::vector new_init_iter_values_; + std::unordered_set new_spatial_loop_var_names_; + Expr new_update_stmt_; + + Expr new_update_sch_block_; + Expr new_update_block_realize_; + Expr new_init_sch_block_; + Expr new_init_block_realize_; +}; + +// Implement class for building Reduction-Factorized block, +// only used for FactorizeReduction schedule primitive. 
+class RFBlockCreater : public ReduceBlockCreater { + public: + RFBlockCreater(const Expr& original_block, + const std::vector& original_loops, + const Expr& rf_loop, + const Expr& original_update_stmt, + const ir::Tensor& rf_tensor, + const std::map& var2loops, + int rf_axis) + : ReduceBlockCreater(original_block, + original_loops, + rf_loop, + original_update_stmt, + rf_tensor, + true), + var2loops_(var2loops), + rf_axis_(rf_axis) {} + + private: + void CreateRFIter() override { + std::string loop_var_name = rf_loop_.As()->loop_var->name; + std::string rf_var_name = "v" + loop_var_name; + rf_var_ = Var(rf_loop_.As()->min, + rf_loop_.As()->extent, + rf_var_name, + /* is_reduce = */ false); + loop_var2block_iters_[rf_loop_.As()->loop_var] = rf_var_; + new_iter_vars_.push_back(rf_var_); + new_iter_values_.push_back(rf_loop_.As()->loop_var); + new_init_iter_vars_.push_back(rf_var_); + new_init_iter_values_.push_back(rf_loop_.As()->loop_var); + new_spatial_loop_var_names_.insert(rf_loop_.As()->loop_var->name); + VLOG(4) << "create new_rf_var = " << rf_var_ + << ", with iter value = " << new_iter_values_.back(); + } + + void CreateNormalIter(int idx) override { + Var original_iter_var = original_block_.As() + ->schedule_block.As() + ->iter_vars[idx]; + Expr original_iter_value = + original_block_.As()->iter_values[idx]; + // The original iter is either a spatial iter, or a reduction iter that + // doesn't touch the rf loop. In this case reuse the old iter var and its + // corresponding iter value. 
+ if (!original_iter_var->is_reduce_axis) { + new_iter_vars_.push_back(original_iter_var); + new_iter_values_.push_back(original_iter_value); + new_init_iter_vars_.push_back(original_iter_var); + new_init_iter_values_.push_back(original_iter_value); + ir_utils::CollectIRNodesWithoutTensor( + original_iter_value, [&](const Expr* x) { + if (x->as_var()) { + new_spatial_loop_var_names_.insert(x->as_var()->name); + } + return false; + }); + return; + } else if (!ContainVar({original_iter_value}, + rf_loop_.As()->loop_var->name)) { + new_iter_vars_.push_back(original_iter_var); + new_iter_values_.push_back(original_iter_value); + return; + } + CHECK(original_iter_var->is_reduce_axis); + + // This iter is a reduction iter and touches the rfactor loop. So we try to + // create a new iter for each loop var that appear in the original iter + // value. + std::vector vars_in_original_iter_values; + ir_utils::CollectIRNodesWithoutTensor( + original_iter_value, [&](const Expr* x) { + if (x->as_var()) { + vars_in_original_iter_values.push_back(x->as_var_ref()); + } + return false; + }); + for (const Var& loop_var : vars_in_original_iter_values) { + if (var2loops_.count(loop_var) == 0) { + continue; + } + Expr loop = var2loops_.at(loop_var); + if (loop_var2block_iters_.count(loop_var) == 0) { + Var new_iter_var(loop.As()->min, + loop.As()->extent, + "v" + loop_var->name, + /* is_reduce = */ true); + new_iter_vars_.push_back(new_iter_var); + new_iter_values_.emplace_back(loop_var); + loop_var2block_iters_[loop_var] = new_iter_var; + } + } + // Substitute the original iter values with new iter vars, + // and store the new iter values in original_indice2new_expr_, + // it will be used in Load/Store indices. 
+ Expr new_iters = ir_utils::IRCopy(original_iter_value); + ReplaceExpr(&new_iters, loop_var2block_iters_); + original_indice2new_expr_[original_iter_var] = new_iters; + VLOG(4) << "original_indice2new_expr_[" << original_iter_var + << "] = " << new_iters; + } + + void CreateUpdateStmt() override { + rf_tensor_access_indices_ = original_update_stmt_.As()->indices; + rf_tensor_access_indices_.insert( + rf_tensor_access_indices_.begin() + rf_axis_, rf_var_); + Expr original_store_body = original_update_stmt_.As()->value; + Expr new_store_body = ir_utils::IRCopy(original_store_body); +#define REPLACE_RF_TENSOR(Op) \ + if (new_store_body.As()) { \ + auto* node = new_store_body.As(); \ + CHECK(node); \ + auto& operand = node->a(); \ + operand = Load::Make(rf_tensor_, rf_tensor_access_indices_); \ + } + + REPLACE_RF_TENSOR(Add) + REPLACE_RF_TENSOR(Mul) + REPLACE_RF_TENSOR(Max) + REPLACE_RF_TENSOR(Min) +#undef REPLACE_RF_TENSOR + + new_update_stmt_ = + ir::Store::Make(rf_tensor_, new_store_body, rf_tensor_access_indices_); + ReplaceExpr(&new_update_stmt_, original_indice2new_expr_); + VLOG(4) << "new_update_stmt of rf block: \n" << new_update_stmt_; + } + + private: + const std::map& var2loops_; + int rf_axis_; + + std::map loop_var2block_iters_; +}; + +// Implement class for building Writing-Back block, +// only used for FactorizeReduction schedule primitive. 
+class RBBlockCreater : public ReduceBlockCreater { + public: + RBBlockCreater(const Expr& original_block, + const std::vector& original_loops, + const Expr& rf_loop, + const Expr& original_update_stmt, + const ir::Tensor& rf_tensor, + const std::vector& rf_tensor_access_indices, + const Var& rf_block_rf_iter_var) + : ReduceBlockCreater(original_block, + original_loops, + rf_loop, + original_update_stmt, + rf_tensor, + false), + rf_tensor_access_indices_(rf_tensor_access_indices), + rf_block_rf_iter_var_(rf_block_rf_iter_var) {} + + private: + void CreateRFIter() override { + std::string loop_var_name = rf_loop_.As()->loop_var->name; + std::string rf_var_name = "v" + loop_var_name; + rf_var_ = Var(rf_loop_.As()->min, + rf_loop_.As()->extent, + rf_var_name, + /* is_reduce = */ true); + new_iter_vars_.push_back(rf_var_); + new_iter_values_.push_back(rf_loop_.As()->loop_var); + original_indice2new_expr_[rf_block_rf_iter_var_] = Expr(rf_var_); + VLOG(4) << "create new_rf_var = " << rf_var_ + << ", with iter value = " << new_iter_values_.back(); + } + + void CreateNormalIter(int idx) override { + Var original_iter_var = original_block_.As() + ->schedule_block.As() + ->iter_vars[idx]; + Expr original_iter_value = + original_block_.As()->iter_values[idx]; + if (!original_iter_var->is_reduce_axis) { + new_iter_vars_.push_back(original_iter_var); + new_iter_values_.push_back(original_iter_value); + new_init_iter_vars_.push_back(original_iter_var); + new_init_iter_values_.push_back(original_iter_value); + ir_utils::CollectIRNodesWithoutTensor( + original_iter_value, [&](const Expr* x) { + if (x->as_var()) { + new_spatial_loop_var_names_.insert(x->as_var()->name); + } + return false; + }); + // original_indice2new_expr_[original_iter_var] = new_iter_vars_.back(); + VLOG(4) << "create new iter var = " << new_iter_vars_.back() + << ", with iter value = " << new_iter_values_.back(); + } + } + + void CreateUpdateStmt() override { + Expr original_store_body = 
original_update_stmt_.As()->value; + Expr new_store_body = ir_utils::IRCopy(original_store_body); +#define REPLACE_RF_TENSOR(Op) \ + if (new_store_body.As()) { \ + auto* node = new_store_body.As(); \ + CHECK(node); \ + auto& operand = node->b(); \ + operand = Load::Make(rf_tensor_, rf_tensor_access_indices_); \ + } + + REPLACE_RF_TENSOR(Add) + REPLACE_RF_TENSOR(Mul) + REPLACE_RF_TENSOR(Max) + REPLACE_RF_TENSOR(Min) +#undef REPLACE_RF_TENSOR + + Expr original_store_tensor = original_update_stmt_.As()->tensor; + std::vector original_store_indices = + original_update_stmt_.As()->indices; + new_update_stmt_ = ir::Store::Make( + original_store_tensor, new_store_body, original_store_indices); + ReplaceExpr(&new_update_stmt_, original_indice2new_expr_); + VLOG(4) << "new_update_stmt of write back block: \n" << new_update_stmt_; + } + + private: + const std::vector& rf_tensor_access_indices_; + const Var& rf_block_rf_iter_var_; +}; + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/schedule/ir_schedule.cc b/paddle/cinn/ir/schedule/ir_schedule.cc index f17e17b73019d..24f97b6e03d1e 100644 --- a/paddle/cinn/ir/schedule/ir_schedule.cc +++ b/paddle/cinn/ir/schedule/ir_schedule.cc @@ -33,6 +33,7 @@ #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/ir_visitor.h" #include "paddle/cinn/ir/op/ir_operators.h" +#include "paddle/cinn/ir/schedule/factorize_reduction.h" #include "paddle/cinn/ir/schedule/ir_schedule_error.h" #include "paddle/cinn/ir/schedule/ir_schedule_util.h" #include "paddle/cinn/ir/utils/ir_copy.h" @@ -120,6 +121,7 @@ class ScheduleImpl { void ReverseComputeInline(const Expr& schedule_block); void Bind(const Expr& loop, const std::string& thread_axis); Expr Rfactor(const Expr& rf_loop, int rf_axis); + Expr FactorizeReduction(const Expr& rf_loop, int rf_axis); Expr AddUnitLoop(const Expr& block) const; void Annotate(const Expr& block, const std::string& key, const attr_t& value); void Unannotate(Expr& block, const std::string& key); // 
NOLINT @@ -717,6 +719,79 @@ Expr ScheduleImpl::Rfactor(const Expr& rf_loop, int rf_axis) { return rf_create.CreateRfAllStmts(); } +Expr ScheduleImpl::FactorizeReduction(const Expr& rf_loop, int rf_axis) { + std::string primitive = "FactorizeReduction"; + // Get child block of the rf_loop and check. + std::vector blocks = GetChildBlocks(rf_loop); + if (blocks.size() != 1) { + std::ostringstream os; + os << "The rf_loop is required to have only one child block, but got " + << blocks.size() << std::endl; + throw IRScheduleErrorHandler(primitive, os.str(), this->module_expr_); + } + Expr original_block = blocks.at(0); + Expr root_block = GetRootBlock(original_block); + // TODO(BiynXu): Add CheckReductionBlock() + + // Collect the loops of the block. + // Construct a map from loop var names to corresponding loops. + std::vector original_loops = this->GetLoops(original_block); + CHECK_GT(original_loops.size(), 0); + VLOG(3) << "before FactorizeReduction, original computational body of the " + "reduction is:\n" + << original_loops[0]; + std::map var2loops; + for (const Expr& loop : original_loops) { + var2loops[loop.As()->loop_var] = loop; + } + + // Get original stmt of reduction update and original store tensor. + Expr original_update_body = original_block.As() + ->schedule_block.As() + ->body; + Expr original_update_stmt; + CHECK(original_update_body.As() || original_update_body.As()); + if (original_update_body.As()) { + CHECK_EQ(original_update_body.As()->stmts.size(), 1); + original_update_stmt = original_update_body.As()->stmts[0]; + } else if (original_update_body.As()) { + original_update_stmt = original_update_body; + } + Tensor original_tensor = + original_update_stmt.As()->tensor.as_tensor_ref(); + + // Create new blocks and loops. 
+ Tensor rf_tensor = CreateRFTensor(original_tensor, rf_loop, rf_axis); + RFBlockCreater rf_block_creater(original_block, + original_loops, + rf_loop, + original_update_stmt, + rf_tensor, + var2loops, + rf_axis); + rf_block_creater.CreateBlock(); + RBBlockCreater wb_block_creater(original_block, + original_loops, + rf_loop, + original_update_stmt, + rf_tensor, + rf_block_creater.rf_tensor_access_indices_, + rf_block_creater.rf_var_); + wb_block_creater.CreateBlock(); + + Expr rf_body = rf_block_creater.CreateLoops(); + Expr wb_body = wb_block_creater.CreateLoops(); + + Expr new_computational_body = Block::Make({rf_body, wb_body}); + + // Replace and update the AST. + this->Replace(original_loops[0], new_computational_body); + VLOG(3) << "After FactorizeReduction, new computational body of the " + "reduction is:\n" + << new_computational_body; + return rf_tensor; +} + struct CacheReadRewriter : public ir::IRMutator<> { public: static Expr Rewrite(const Expr& root, CacheBlockInfo* info) { @@ -2647,6 +2722,15 @@ Expr IRSchedule::Rfactor(const Expr& rf_loop, int rf_axis) { return result; } +Expr IRSchedule::FactorizeReduction(const Expr& rf_loop, int rf_axis) { + auto result = impl_->FactorizeReduction(rf_loop, rf_axis); + trace_.Append(ScheduleDesc::Step("FactorizeReduction", + {{"rf_loop", std::vector({rf_loop})}}, + {{"rf_axis", rf_axis}}, + {result})); + return result; +} + void IRSchedule::Annotate(const Expr& block, const std::string& key, const attr_t& value) { diff --git a/paddle/cinn/ir/schedule/ir_schedule.h b/paddle/cinn/ir/schedule/ir_schedule.h index ce341c502b1fb..4c5fc1d10f1b6 100644 --- a/paddle/cinn/ir/schedule/ir_schedule.h +++ b/paddle/cinn/ir/schedule/ir_schedule.h @@ -381,6 +381,46 @@ class IRSchedule { */ Expr Rfactor(const Expr& rf_loop, int rf_axis); + /** + * \brief Factorize the reduction block by the given loop. The block will be + * split into two blocks: reduction-factorized block and write-back block. 
+ * @param rf_loop the reduce loop to be factorized. + * @param rf_axis The position where the new dimension is placed in the new rf + * tensor. + * @return The newly created rf tensor. + * + * For example, input the block: + * \code + * for (i, 0, 10) // serial loop + * B_init[i] = 0 + * for (j, 0, 20) // reduce loop + * for (k, 0, 30) // reduce loop + * B[i] = B[i] + A[i, j, k] + * \endcode + * + * If the rf loop is j and rf_axis is 0, the transformation is + * divided into 2 steps: + * 1. get the rf block where the reduce loop j is transformed to the + * serial loop with no accumulation and a new rf tensor is created. + * The axis j will be placed in the rf_axis of the new rf_tensor. + * The rf_block is as follows: + * \code + * for (i, 0, 10) // serial loop + * for (j, 0, 20) // rf loop j is transformed to the serial loop + * rf_B_init[j, i] = 0 + * for (k, 0, 30) // reduce loop. + * rf_B[j, i] = rf_B[j, i] + A[i, j, k] + * \endcode + * 2. do reduction of the rf loop j to get the final result block: + * \code + * for (i, 0, 10) // serial loop + * B_init[i] = 0 + * for (j, 0, 20) // rf reduction loop + * B[i] = B[i] + rf_B[j, i] + * \endcode + */ + Expr FactorizeReduction(const Expr& rf_loop, int rf_axis); + /*! 
* \brief Annotate a block with a key-value pair to set as its attribute * \param block The block to be annotated diff --git a/paddle/cinn/ir/schedule/ir_schedule_util.cc b/paddle/cinn/ir/schedule/ir_schedule_util.cc index 7144e1484a58c..7a2daa3106612 100644 --- a/paddle/cinn/ir/schedule/ir_schedule_util.cc +++ b/paddle/cinn/ir/schedule/ir_schedule_util.cc @@ -221,6 +221,14 @@ void ReplaceExpr(Expr* source, return; } +void ReplaceExpr(Expr* source, + const std::map& replacing_map) { + if (replacing_map.empty()) return; + MappingVarToExprMutator mapper(replacing_map); + mapper(source); + return; +} + std::vector ValidateFactors(const std::vector& factors, int total_extent, const ModuleExpr& module_expr) { diff --git a/paddle/cinn/ir/schedule/ir_schedule_util.h b/paddle/cinn/ir/schedule/ir_schedule_util.h index 50515e5f3cfa9..9c9418b4d577e 100644 --- a/paddle/cinn/ir/schedule/ir_schedule_util.h +++ b/paddle/cinn/ir/schedule/ir_schedule_util.h @@ -193,7 +193,7 @@ Tensor GetReadTensor(const Expr& block, int index); int GetLoopExtent(const Expr& loop); /** - * \brief Given a vector of Exors, return whether they contain a var with + * \brief Given a vector of Exprs, return whether they contain a var with * specific name. * @param exprs The given vector of Exprs * @param var_name The name of specific var @@ -241,6 +241,15 @@ void ReplaceExpr(Expr* source, const std::vector& replaced, const std::vector& candidates); +/** + * Replace Vars in replaced to Exprs in candidates in source. + * @param source The Expr we will implement the change. + * @param replacing_map The one-to-one corresponded Vars -> Exprs to be + * replaced. + */ +void ReplaceExpr(Expr* source, + const std::map& replacing_map); + /** * Validate the factors param of Split. We will check if factors are validate * and change -1 to positive integer. 
diff --git a/paddle/cinn/ir/schedule/schedule_desc.cc b/paddle/cinn/ir/schedule/schedule_desc.cc index a3ef7e72a1bc9..e0d5f4ab21701 100644 --- a/paddle/cinn/ir/schedule/schedule_desc.cc +++ b/paddle/cinn/ir/schedule/schedule_desc.cc @@ -474,6 +474,12 @@ CINN_BUILD_STEP_KIND(Rfactor) .SetApplyFn( APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Rfactor))); +CINN_BUILD_STEP_KIND(FactorizeReduction) + .Inputs({"rf_loop"}) + .Attrs({"rf_axis"}) + .SetApplyFn(APPLY_FUNC_UNIFORM( + FREE_FUNCTION_CONVERTER(&IRSchedule::FactorizeReduction))); + CINN_BUILD_STEP_KIND(MergeExprs) .SetApplyFn( APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::MergeExprs)));