diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index f1a8a1ae06661..77ca33f2450cd 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -47,6 +47,8 @@ bool whole_kernel_cse(IRNode *root); bool extract_constant(IRNode *root, const CompileConfig &config); bool unreachable_code_elimination(IRNode *root); bool loop_invariant_code_motion(IRNode *root, const CompileConfig &config); +bool cache_loop_invariant_global_vars(IRNode *root, + const CompileConfig &config); void full_simplify(IRNode *root, const CompileConfig &config, const FullSimplifyPass::Args &args); diff --git a/taichi/program/compile_config.h b/taichi/program/compile_config.h index 704e25a480df1..8868f7a642700 100644 --- a/taichi/program/compile_config.h +++ b/taichi/program/compile_config.h @@ -28,6 +28,7 @@ struct CompileConfig { bool lower_access; bool simplify_after_lower_access; bool move_loop_invariant_outside_if; + bool cache_loop_invariant_global_vars{true}; bool demote_dense_struct_fors; bool advanced_optimization; bool constant_folding; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 3795ba380aebb..0c70a7170b1e8 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -166,6 +166,8 @@ void export_lang(py::module &m) { .def_readwrite("lower_access", &CompileConfig::lower_access) .def_readwrite("move_loop_invariant_outside_if", &CompileConfig::move_loop_invariant_outside_if) + .def_readwrite("cache_loop_invariant_global_vars", + &CompileConfig::cache_loop_invariant_global_vars) .def_readwrite("default_cpu_block_dim", &CompileConfig::default_cpu_block_dim) .def_readwrite("cpu_block_dim_adaptive", diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp new file mode 100644 index 0000000000000..5310e5d94c8a0 --- /dev/null +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -0,0 +1,180 @@ +#include "taichi/transforms/loop_invariant_detector.h" +#include "taichi/ir/analysis.h" + +namespace taichi::lang { + +class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { + public: + using LoopInvariantDetector::visit; + + enum class CacheStatus { + None = 0, + Read = 1, + Write = 2, + ReadWrite = 3, + }; + + typedef std::unordered_map> + CacheMap; + std::stack cached_maps; + + DelayedIRModifier modifier; + std::unordered_map loop_unique_ptr_; + std::unordered_map loop_unique_arr_ptr_; + + OffloadedStmt *current_offloaded; + + explicit CacheLoopInvariantGlobalVars(const CompileConfig &config) + : LoopInvariantDetector(config) { + } + + void visit(OffloadedStmt *stmt) override { + if (stmt->task_type == OffloadedTaskType::range_for || + stmt->task_type == OffloadedTaskType::mesh_for || + stmt->task_type == OffloadedTaskType::struct_for) { + auto uniquely_accessed_pointers = + irpass::analysis::gather_uniquely_accessed_pointers(stmt); + loop_unique_ptr_ = std::move(uniquely_accessed_pointers.first); + loop_unique_arr_ptr_ = std::move(uniquely_accessed_pointers.second); + } + current_offloaded = stmt; + // We don't need to visit TLS/BLS prologues/epilogues. + if (stmt->body) { + if (stmt->task_type == OffloadedStmt::TaskType::range_for || + stmt->task_type == OffloadedTaskType::mesh_for || + stmt->task_type == OffloadedStmt::TaskType::struct_for) + visit_loop(stmt->body.get()); + else + stmt->body->accept(this); + } + current_offloaded = nullptr; + } + + bool is_offload_unique(Stmt *stmt) { + if (current_offloaded->task_type == OffloadedTaskType::serial) { + return true; + } + if (auto global_ptr = stmt->cast()) { + auto snode = global_ptr->snode; + if (loop_unique_ptr_[snode] == nullptr || + loop_unique_ptr_[snode]->indices.empty()) { + // not uniquely accessed + return false; + } + if (current_offloaded->mem_access_opt.has_flag( + snode, SNodeAccessFlag::block_local) || + current_offloaded->mem_access_opt.has_flag( + snode, SNodeAccessFlag::mesh_local)) { + // BLS does not support write access yet so we keep atomic_adds. + return false; + } + return true; + } else if (stmt->is()) { + ExternalPtrStmt *dest_ptr = stmt->as(); + if (dest_ptr->indices.empty()) { + return false; + } + ArgLoadStmt *arg_load_stmt = dest_ptr->base_ptr->as(); + int arg_id = arg_load_stmt->arg_id; + if (loop_unique_arr_ptr_[arg_id] == nullptr) { + // Not loop unique + return false; + } + return true; + // TODO: Is BLS / Mem Access Opt a thing for any_arr? + } + return false; + } + + void visit_loop(Block *body) override { + cached_maps.emplace(); + LoopInvariantDetector::visit_loop(body); + cached_maps.pop(); + } + + void add_writeback(AllocaStmt *alloca_stmt, Stmt *global_var) { + auto final_value = std::make_unique(alloca_stmt); + auto global_store = + std::make_unique(global_var, final_value.get()); + modifier.insert_after(current_loop_stmt(), std::move(global_store)); + modifier.insert_after(current_loop_stmt(), std::move(final_value)); + } + + void set_init_value(AllocaStmt *alloca_stmt, Stmt *global_var) { + auto new_global_load = std::make_unique(global_var); + auto local_store = + std::make_unique(alloca_stmt, new_global_load.get()); + modifier.insert_before(current_loop_stmt(), std::move(new_global_load)); + modifier.insert_before(current_loop_stmt(), std::move(local_store)); + } + + AllocaStmt *cache_global_to_local(Stmt *dest, CacheStatus status) { + if (auto &[cached_status, alloca_stmt] = cached_maps.top()[dest]; + cached_status != CacheStatus::None) { + // The global variable has already been cached. + if (cached_status == CacheStatus::Read && status == CacheStatus::Write) { + add_writeback(alloca_stmt, dest); + cached_status = CacheStatus::ReadWrite; + } + return alloca_stmt; + } + auto alloca_unique = + std::make_unique(dest->ret_type.ptr_removed()); + auto alloca_stmt = alloca_unique.get(); + modifier.insert_before(current_loop_stmt(), std::move(alloca_unique)); + if (status == CacheStatus::Read) { + set_init_value(alloca_stmt, dest); + } else if (status == CacheStatus::Write) { + add_writeback(alloca_stmt, dest); + } + cached_maps.top()[dest] = {status, alloca_stmt}; + return alloca_stmt; + } + + void visit(GlobalLoadStmt *stmt) override { + if (is_offload_unique(stmt->src) && + is_operand_loop_invariant(stmt->src, stmt->parent)) { + auto alloca_stmt = cache_global_to_local(stmt->src, CacheStatus::Read); + auto local_load = std::make_unique(alloca_stmt); + stmt->replace_usages_with(local_load.get()); + modifier.insert_before(stmt, std::move(local_load)); + modifier.erase(stmt); + } + } + + void visit(GlobalStoreStmt *stmt) override { + if (is_offload_unique(stmt->dest) && + is_operand_loop_invariant(stmt->dest, stmt->parent)) { + auto alloca_stmt = cache_global_to_local(stmt->dest, CacheStatus::Write); + auto local_store = + std::make_unique(alloca_stmt, stmt->val); + stmt->replace_usages_with(local_store.get()); + modifier.insert_before(stmt, std::move(local_store)); + modifier.erase(stmt); + } + } + + static bool run(IRNode *node, const CompileConfig &config) { + bool modified = false; + + while (true) { + CacheLoopInvariantGlobalVars eliminator(config); + node->accept(&eliminator); + if (eliminator.modifier.modify_ir()) + modified = true; + else + break; + }; + + return modified; + } +}; + +namespace irpass { +bool cache_loop_invariant_global_vars(IRNode *root, + const CompileConfig &config) { + TI_AUTO_PROF; + return CacheLoopInvariantGlobalVars::run(root, config); +} +} // namespace irpass +} // namespace taichi::lang diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 0edd297e3b773..58d00e0c2332f 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -186,6 +186,10 @@ void offload_to_executable(IRNode *ir, irpass::demote_atomics(ir, config); print("Atomics demoted I"); irpass::analysis::verify(ir); + if (config.cache_loop_invariant_global_vars) { + irpass::cache_loop_invariant_global_vars(ir, config); + print("Cache loop-invariant global vars"); + } if (config.demote_dense_struct_fors) { irpass::demote_dense_struct_fors(ir, config.packed); @@ -246,6 +250,9 @@ void offload_to_executable(IRNode *ir, irpass::analysis::verify(ir); if (lower_global_access) { + irpass::full_simplify(ir, config, + {false, /*autodiff_enabled*/ false, kernel->program}); + print("Simplified before lower access"); irpass::lower_access(ir, config, {kernel->no_activate, true}); print("Access lowered"); irpass::analysis::verify(ir); diff --git a/taichi/transforms/loop_invariant_code_motion.cpp b/taichi/transforms/loop_invariant_code_motion.cpp index 49f735ff2ea7b..8e9fc4c4f9a11 100644 --- a/taichi/transforms/loop_invariant_code_motion.cpp +++ b/taichi/transforms/loop_invariant_code_motion.cpp @@ -1,147 +1,68 @@ -#include "taichi/ir/ir.h" -#include "taichi/ir/statements.h" -#include "taichi/ir/transforms.h" -#include "taichi/ir/visitors.h" -#include "taichi/system/profiler.h" - -#include +#include "taichi/transforms/loop_invariant_detector.h" TLANG_NAMESPACE_BEGIN -class LoopInvariantCodeMotion : public BasicStmtVisitor { +class LoopInvariantCodeMotion : public LoopInvariantDetector { public: - using BasicStmtVisitor::visit; - - std::stack loop_blocks; - - const CompileConfig &config; + using LoopInvariantDetector::visit; DelayedIRModifier modifier; explicit LoopInvariantCodeMotion(const CompileConfig &config) - : config(config) { - allow_undefined_visitor = true; - } - - bool stmt_can_be_moved(Stmt *stmt) { - if (loop_blocks.size() <= 1 || (!config.move_loop_invariant_outside_if && - stmt->parent != loop_blocks.top())) - return false; - - bool can_be_moved = true; - - Block *current_scope = stmt->parent; - - for (Stmt *operand : stmt->get_operands()) { - if (operand->parent == current_scope) { - // This statement has an operand that is in the current scope, - // so it can not be moved out of the scope. - can_be_moved = false; - break; - } - if (config.move_loop_invariant_outside_if && - stmt->parent != loop_blocks.top()) { - // If we enable moving code from a nested if block, we need to check - // visibility. Example: - // for i in range(10): - // a = x[0] - // if b: - // c = a + 1 - // Since we are moving statements outside the cloest for scope, - // We need to check the scope of the operand - Stmt *operand_parent = operand; - while (operand_parent && operand_parent->parent) { - operand_parent = operand_parent->parent->parent_stmt; - if (!operand_parent) - break; - // If the one of the parent of the operand is the top loop scope - // Then it will not be visible if we move it outside the top loop - // scope - if (operand_parent == loop_blocks.top()->parent_stmt) { - can_be_moved = false; - break; - } - } - if (!can_be_moved) - break; - } - } - - return can_be_moved; + : LoopInvariantDetector(config) { } void visit(BinaryOpStmt *stmt) override { - if (stmt_can_be_moved(stmt)) { + if (is_loop_invariant(stmt, stmt->parent)) { auto replacement = stmt->clone(); stmt->replace_usages_with(replacement.get()); - modifier.insert_before(stmt->parent->parent_stmt, std::move(replacement)); + modifier.insert_before(current_loop_stmt(), std::move(replacement)); modifier.erase(stmt); } } void visit(UnaryOpStmt *stmt) override { - if (stmt_can_be_moved(stmt)) { + if (is_loop_invariant(stmt, stmt->parent)) { auto replacement = stmt->clone(); stmt->replace_usages_with(replacement.get()); - modifier.insert_before(stmt->parent->parent_stmt, std::move(replacement)); + modifier.insert_before(current_loop_stmt(), std::move(replacement)); modifier.erase(stmt); } } - void visit(Block *stmt_list) override { - for (auto &stmt : stmt_list->statements) - stmt->accept(this); - } - - void visit_loop(Block *body) { - loop_blocks.push(body); - - body->accept(this); - - loop_blocks.pop(); - } - - void visit(RangeForStmt *stmt) override { - visit_loop(stmt->body.get()); - } + void visit(GlobalPtrStmt *stmt) override { + if (config.cache_loop_invariant_global_vars && + is_loop_invariant(stmt, stmt->parent)) { + auto replacement = stmt->clone(); + stmt->replace_usages_with(replacement.get()); - void visit(StructForStmt *stmt) override { - visit_loop(stmt->body.get()); + modifier.insert_before(current_loop_stmt(), std::move(replacement)); + modifier.erase(stmt); + } } - void visit(MeshForStmt *stmt) override { - visit_loop(stmt->body.get()); - } + void visit(ExternalPtrStmt *stmt) override { + if (config.cache_loop_invariant_global_vars && + is_loop_invariant(stmt, stmt->parent)) { + auto replacement = stmt->clone(); + stmt->replace_usages_with(replacement.get()); - void visit(WhileStmt *stmt) override { - visit_loop(stmt->body.get()); + modifier.insert_before(current_loop_stmt(), std::move(replacement)); + modifier.erase(stmt); + } } - void visit(OffloadedStmt *stmt) override { - if (stmt->tls_prologue) - stmt->tls_prologue->accept(this); - - if (stmt->mesh_prologue) - stmt->mesh_prologue->accept(this); - - if (stmt->bls_prologue) - stmt->bls_prologue->accept(this); + void visit(ArgLoadStmt *stmt) override { + if (config.cache_loop_invariant_global_vars && + is_loop_invariant(stmt, stmt->parent)) { + auto replacement = stmt->clone(); + stmt->replace_usages_with(replacement.get()); - if (stmt->body) { - if (stmt->task_type == OffloadedStmt::TaskType::range_for || - stmt->task_type == OffloadedStmt::TaskType::struct_for) - visit_loop(stmt->body.get()); - else - stmt->body->accept(this); + modifier.insert_before(current_loop_stmt(), std::move(replacement)); + modifier.erase(stmt); } - - if (stmt->bls_epilogue) - stmt->bls_epilogue->accept(this); - - if (stmt->tls_epilogue) - stmt->tls_epilogue->accept(this); } static bool run(IRNode *node, const CompileConfig &config) { diff --git a/taichi/transforms/loop_invariant_detector.h b/taichi/transforms/loop_invariant_detector.h new file mode 100644 index 0000000000000..d9ce963c505ca --- /dev/null +++ b/taichi/transforms/loop_invariant_detector.h @@ -0,0 +1,134 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" +#include "taichi/system/profiler.h" + +#include + +namespace taichi::lang { + +class LoopInvariantDetector : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + + std::stack loop_blocks; + + const CompileConfig &config; + + explicit LoopInvariantDetector(const CompileConfig &config) : config(config) { + allow_undefined_visitor = true; + } + + bool is_operand_loop_invariant_impl(Stmt *operand, Block *current_scope) { + if (operand->parent == current_scope) { + // This statement has an operand that is in the current scope, + // so it can not be moved out of the scope. + return false; + } + if (current_scope != loop_blocks.top()) { + // If we enable moving code from a nested if block, we need to check + // visibility. Example: + // for i in range(10): + // a = x[0] + // if b: + // c = a + 1 + // Since we are moving statements outside the closest for scope, + // We need to check the scope of the operand + Stmt *operand_parent = operand; + while (operand_parent->parent) { + operand_parent = operand_parent->parent->parent_stmt; + if (!operand_parent) + break; + // If the one of the current_scope of the operand is the top loop + // scope Then it will not be visible if we move it outside the top + // loop scope + if (operand_parent == loop_blocks.top()->parent_stmt) { + return false; + } + } + } + return true; + } + + bool is_operand_loop_invariant(Stmt *operand, Block *current_scope) { + if (loop_blocks.size() <= 1) + return false; + return is_operand_loop_invariant_impl(operand, current_scope); + } + + bool is_loop_invariant(Stmt *stmt, Block *current_scope) { + if (loop_blocks.size() <= 1 || (!config.move_loop_invariant_outside_if && + current_scope != loop_blocks.top())) + return false; + + bool is_invariant = true; + + for (Stmt *operand : stmt->get_operands()) { + is_invariant &= is_operand_loop_invariant_impl(operand, current_scope); + } + + return is_invariant; + } + + Stmt *current_loop_stmt() { + return loop_blocks.top()->parent_stmt; + } + + void visit(Block *stmt_list) override { + for (auto &stmt : stmt_list->statements) + stmt->accept(this); + } + + virtual void visit_loop(Block *body) { + loop_blocks.push(body); + + body->accept(this); + + loop_blocks.pop(); + } + + void visit(RangeForStmt *stmt) override { + visit_loop(stmt->body.get()); + } + + void visit(StructForStmt *stmt) override { + visit_loop(stmt->body.get()); + } + + void visit(MeshForStmt *stmt) override { + visit_loop(stmt->body.get()); + } + + void visit(WhileStmt *stmt) override { + visit_loop(stmt->body.get()); + } + + void visit(OffloadedStmt *stmt) override { + if (stmt->tls_prologue) + stmt->tls_prologue->accept(this); + + if (stmt->mesh_prologue) + stmt->mesh_prologue->accept(this); + + if (stmt->bls_prologue) + stmt->bls_prologue->accept(this); + + if (stmt->body) { + if (stmt->task_type == OffloadedStmt::TaskType::range_for || + stmt->task_type == OffloadedTaskType::mesh_for || + stmt->task_type == OffloadedStmt::TaskType::struct_for) + visit_loop(stmt->body.get()); + else + stmt->body->accept(this); + } + + if (stmt->bls_epilogue) + stmt->bls_epilogue->accept(this); + + if (stmt->tls_epilogue) + stmt->tls_epilogue->accept(this); + } +}; + +} // namespace taichi::lang