From 1e439e68d22dd34f10eaaf7653364ad07fd2af4c Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Thu, 15 Sep 2022 18:35:59 +0800 Subject: [PATCH 01/16] [opt] Cache loop-invariant global vars to local vars --- taichi/ir/transforms.h | 1 + .../cache_loop_invariant_global_vars.cpp | 96 +++++++++++++ taichi/transforms/compile_to_offloads.cpp | 3 + .../transforms/loop_invariant_code_motion.cpp | 126 ++---------------- taichi/transforms/loop_invariant_detector.h | 124 +++++++++++++++++ taichi/transforms/simplify.cpp | 2 + 6 files changed, 238 insertions(+), 114 deletions(-) create mode 100644 taichi/transforms/cache_loop_invariant_global_vars.cpp create mode 100644 taichi/transforms/loop_invariant_detector.h diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index f1a8a1ae06661..fba2f0c75660b 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -47,6 +47,7 @@ bool whole_kernel_cse(IRNode *root); bool extract_constant(IRNode *root, const CompileConfig &config); bool unreachable_code_elimination(IRNode *root); bool loop_invariant_code_motion(IRNode *root, const CompileConfig &config); +bool cache_loop_invariant_global_vars(IRNode *root, const CompileConfig &config); void full_simplify(IRNode *root, const CompileConfig &config, const FullSimplifyPass::Args &args); diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp new file mode 100644 index 0000000000000..107d2600332cf --- /dev/null +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -0,0 +1,96 @@ +#include "taichi/transforms/loop_invariant_detector.h" + +TLANG_NAMESPACE_BEGIN + +class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { + public: + using LoopInvariantDetector::visit; + + enum class CacheStatus { + ReadOnly = 0, + ReadWrite = 1 + }; + + std::unordered_map> cached_allocas; + + DelayedIRModifier modifier; + + explicit CacheLoopInvariantGlobalVars(const CompileConfig &config) + : LoopInvariantDetector(config) { + } + + void add_writeback(AllocaStmt *alloca_stmt, Stmt *stmt, Stmt *parent_stmt) { + auto final_value = std::make_unique(alloca_stmt); + auto global_store = std::make_unique(stmt, final_value.get()); + modifier.insert_after(parent_stmt, std::move(global_store)); + modifier.insert_after(parent_stmt, std::move(final_value)); + } + + AllocaStmt *cache_global_to_local(Stmt *stmt, Stmt *parent_stmt, CacheStatus status) { + if (auto &[cached, cached_status] = cached_allocas[stmt]; cached) { + if (cached_status == CacheStatus::ReadOnly && status == CacheStatus::ReadWrite) { + add_writeback(cached, stmt, parent_stmt); + cached_status = CacheStatus::ReadWrite; + } + return cached; + } + + auto alloca_unique = std::make_unique(stmt->ret_type.ptr_removed()); + auto alloca_stmt = alloca_unique.get(); + cached_allocas[stmt] = {alloca_stmt, status}; + auto new_global_load = std::make_unique(stmt); + auto local_store = std::make_unique(alloca_stmt, new_global_load.get()); + modifier.insert_before(parent_stmt, std::move(new_global_load)); + modifier.insert_before(parent_stmt, std::move(alloca_unique)); + modifier.insert_before(parent_stmt, std::move(local_store)); + + if (status == CacheStatus::ReadWrite) { + add_writeback(alloca_stmt, stmt, parent_stmt); + } + return alloca_stmt; + } + + void visit(GlobalLoadStmt *stmt) override { + if (is_loop_invariant(stmt->src, stmt->parent)) { + auto alloca_stmt = cache_global_to_local(stmt->src, stmt->parent->parent_stmt, CacheStatus::ReadOnly); + auto local_load = std::make_unique(alloca_stmt); + stmt->replace_usages_with(local_load.get()); + modifier.insert_before(stmt, std::move(local_load)); + modifier.erase(stmt); + } + } + + void visit(GlobalStoreStmt *stmt) override { + if (is_loop_invariant(stmt->dest, stmt->parent)) { + auto alloca_stmt = cache_global_to_local(stmt->dest, stmt->parent->parent_stmt, CacheStatus::ReadWrite); + auto local_store = std::make_unique(alloca_stmt, stmt->val); + stmt->replace_usages_with(local_store.get()); + modifier.insert_before(stmt, std::move(local_store)); + modifier.erase(stmt); + } + } + + static bool run(IRNode *node, const CompileConfig &config) { + bool modified = false; + + while (true) { + CacheLoopInvariantGlobalVars eliminator(config); + node->accept(&eliminator); + if (eliminator.modifier.modify_ir()) + modified = true; + else + break; + }; + + return modified; + } +}; + +namespace irpass { +bool cache_loop_invariant_global_vars(IRNode *root, const CompileConfig &config) { + TI_AUTO_PROF; + return CacheLoopInvariantGlobalVars::run(root, config); +} +} // namespace irpass + +TLANG_NAMESPACE_END diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 0edd297e3b773..f540ddb39622f 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -246,6 +246,9 @@ void offload_to_executable(IRNode *ir, irpass::analysis::verify(ir); if (lower_global_access) { + irpass::full_simplify( + ir, config, {false, /*autodiff_enabled*/ false, kernel->program}); + print("Simplified before lower access"); irpass::lower_access(ir, config, {kernel->no_activate, true}); print("Access lowered"); irpass::analysis::verify(ir); diff --git a/taichi/transforms/loop_invariant_code_motion.cpp b/taichi/transforms/loop_invariant_code_motion.cpp index 49f735ff2ea7b..7529f28a4f87f 100644 --- a/taichi/transforms/loop_invariant_code_motion.cpp +++ b/taichi/transforms/loop_invariant_code_motion.cpp @@ -1,77 +1,19 @@ -#include "taichi/ir/ir.h" -#include "taichi/ir/statements.h" -#include "taichi/ir/transforms.h" -#include "taichi/ir/visitors.h" -#include "taichi/system/profiler.h" - -#include +#include "taichi/transforms/loop_invariant_detector.h" TLANG_NAMESPACE_BEGIN -class LoopInvariantCodeMotion : public BasicStmtVisitor { +class LoopInvariantCodeMotion : public LoopInvariantDetector { public: - using BasicStmtVisitor::visit; - - std::stack loop_blocks; - - const CompileConfig &config; + using LoopInvariantDetector::visit; DelayedIRModifier modifier; explicit LoopInvariantCodeMotion(const CompileConfig &config) - : config(config) { - allow_undefined_visitor = true; - } - - bool stmt_can_be_moved(Stmt *stmt) { - if (loop_blocks.size() <= 1 || (!config.move_loop_invariant_outside_if && - stmt->parent != loop_blocks.top())) - return false; - - bool can_be_moved = true; - - Block *current_scope = stmt->parent; - - for (Stmt *operand : stmt->get_operands()) { - if (operand->parent == current_scope) { - // This statement has an operand that is in the current scope, - // so it can not be moved out of the scope. - can_be_moved = false; - break; - } - if (config.move_loop_invariant_outside_if && - stmt->parent != loop_blocks.top()) { - // If we enable moving code from a nested if block, we need to check - // visibility. Example: - // for i in range(10): - // a = x[0] - // if b: - // c = a + 1 - // Since we are moving statements outside the cloest for scope, - // We need to check the scope of the operand - Stmt *operand_parent = operand; - while (operand_parent && operand_parent->parent) { - operand_parent = operand_parent->parent->parent_stmt; - if (!operand_parent) - break; - // If the one of the parent of the operand is the top loop scope - // Then it will not be visible if we move it outside the top loop - // scope - if (operand_parent == loop_blocks.top()->parent_stmt) { - can_be_moved = false; - break; - } - } - if (!can_be_moved) - break; - } - } - - return can_be_moved; + : LoopInvariantDetector(config) { } void visit(BinaryOpStmt *stmt) override { - if (stmt_can_be_moved(stmt)) { + if (is_loop_invariant(stmt, stmt->parent)) { auto replacement = stmt->clone(); stmt->replace_usages_with(replacement.get()); @@ -81,7 +23,7 @@ class LoopInvariantCodeMotion : public BasicStmtVisitor { } void visit(UnaryOpStmt *stmt) override { - if (stmt_can_be_moved(stmt)) { + if (is_loop_invariant(stmt, stmt->parent)) { auto replacement = stmt->clone(); stmt->replace_usages_with(replacement.get()); @@ -90,58 +32,14 @@ class LoopInvariantCodeMotion : public BasicStmtVisitor { } } - void visit(Block *stmt_list) override { - for (auto &stmt : stmt_list->statements) - stmt->accept(this); - } - - void visit_loop(Block *body) { - loop_blocks.push(body); - - body->accept(this); - - loop_blocks.pop(); - } - - void visit(RangeForStmt *stmt) override { - visit_loop(stmt->body.get()); - } - - void visit(StructForStmt *stmt) override { - visit_loop(stmt->body.get()); - } - - void visit(MeshForStmt *stmt) override { - visit_loop(stmt->body.get()); - } - - void visit(WhileStmt *stmt) override { - visit_loop(stmt->body.get()); - } - - void visit(OffloadedStmt *stmt) override { - if (stmt->tls_prologue) - stmt->tls_prologue->accept(this); - - if (stmt->mesh_prologue) - stmt->mesh_prologue->accept(this); - - if (stmt->bls_prologue) - stmt->bls_prologue->accept(this); + void visit(GlobalPtrStmt *stmt) override { + if (is_loop_invariant(stmt, stmt->parent)) { + auto replacement = stmt->clone(); + stmt->replace_usages_with(replacement.get()); - if (stmt->body) { - if (stmt->task_type == OffloadedStmt::TaskType::range_for || - stmt->task_type == OffloadedStmt::TaskType::struct_for) - visit_loop(stmt->body.get()); - else - stmt->body->accept(this); + modifier.insert_before(stmt->parent->parent_stmt, std::move(replacement)); + modifier.erase(stmt); } - - if (stmt->bls_epilogue) - stmt->bls_epilogue->accept(this); - - if (stmt->tls_epilogue) - stmt->tls_epilogue->accept(this); } static bool run(IRNode *node, const CompileConfig &config) { diff --git a/taichi/transforms/loop_invariant_detector.h b/taichi/transforms/loop_invariant_detector.h new file mode 100644 index 0000000000000..b4d2a3568c96e --- /dev/null +++ b/taichi/transforms/loop_invariant_detector.h @@ -0,0 +1,124 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" +#include "taichi/system/profiler.h" + +#include + +TLANG_NAMESPACE_BEGIN + +class LoopInvariantDetector : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + + std::stack loop_blocks; + + const CompileConfig &config; + + explicit LoopInvariantDetector(const CompileConfig &config) + : config(config) { + allow_undefined_visitor = true; + } + + bool is_loop_invariant(Stmt *stmt, Block *current_scope) { + if (loop_blocks.size() <= 1 || (!config.move_loop_invariant_outside_if && + current_scope != loop_blocks.top())) + return false; + + bool is_invariant = true; + + for (Stmt *operand : stmt->get_operands()) { + if (operand->parent == current_scope) { + // This statement has an operand that is in the current scope, + // so it can not be moved out of the scope. + is_invariant = false; + break; + } + if (current_scope != loop_blocks.top()) { + // If we enable moving code from a nested if block, we need to check + // visibility. Example: + // for i in range(10): + // a = x[0] + // if b: + // c = a + 1 + // Since we are moving statements outside the cloest for scope, + // We need to check the scope of the operand + Stmt *operand_parent = operand; + while (operand_parent->parent) { + operand_parent = operand_parent->parent->parent_stmt; + if (!operand_parent) + break; + // If the one of the current_scope of the operand is the top loop scope + // Then it will not be visible if we move it outside the top loop + // scope + if (operand_parent == loop_blocks.top()->parent_stmt) { + is_invariant = false; + break; + } + } + if (!is_invariant) + break; + } + } + + return is_invariant; + } + + void visit(Block *stmt_list) override { + for (auto &stmt : stmt_list->statements) + stmt->accept(this); + } + + void visit_loop(Block *body) { + loop_blocks.push(body); + + body->accept(this); + + loop_blocks.pop(); + } + + void visit(RangeForStmt *stmt) override { + visit_loop(stmt->body.get()); + } + + void visit(StructForStmt *stmt) override { + visit_loop(stmt->body.get()); + } + + void visit(MeshForStmt *stmt) override { + visit_loop(stmt->body.get()); + } + + void visit(WhileStmt *stmt) override { + visit_loop(stmt->body.get()); + } + + void visit(OffloadedStmt *stmt) override { + if (stmt->tls_prologue) + stmt->tls_prologue->accept(this); + + if (stmt->mesh_prologue) + stmt->mesh_prologue->accept(this); + + if (stmt->bls_prologue) + stmt->bls_prologue->accept(this); + + if (stmt->body) { + if (stmt->task_type == OffloadedStmt::TaskType::range_for || + stmt->task_type == OffloadedStmt::TaskType::struct_for) + visit_loop(stmt->body.get()); + else + stmt->body->accept(this); + } + + if (stmt->bls_epilogue) + stmt->bls_epilogue->accept(this); + + if (stmt->tls_epilogue) + stmt->tls_epilogue->accept(this); + } + +}; + +TLANG_NAMESPACE_END diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 6c18c88a9194d..a727a5162a9be 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -632,6 +632,8 @@ void full_simplify(IRNode *root, modified = true; if (config.opt_level > 0 && whole_kernel_cse(root)) modified = true; + if (cache_loop_invariant_global_vars(root, config)) + modified = true; // Don't do this time-consuming optimization pass again if the IR is // not modified. if (config.opt_level > 0 && (first_iteration || modified) && From 79e714061c7e69f05f7dc79035951cbf9c532518 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Sep 2022 10:37:51 +0000 Subject: [PATCH 02/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/ir/transforms.h | 3 +- .../cache_loop_invariant_global_vars.cpp | 36 +++++++++++-------- taichi/transforms/compile_to_offloads.cpp | 4 +-- taichi/transforms/loop_invariant_detector.h | 10 +++--- 4 files changed, 30 insertions(+), 23 deletions(-) diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index fba2f0c75660b..77ca33f2450cd 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -47,7 +47,8 @@ bool whole_kernel_cse(IRNode *root); bool extract_constant(IRNode *root, const CompileConfig &config); bool unreachable_code_elimination(IRNode *root); bool loop_invariant_code_motion(IRNode *root, const CompileConfig &config); -bool cache_loop_invariant_global_vars(IRNode *root, const CompileConfig &config); +bool cache_loop_invariant_global_vars(IRNode *root, + const CompileConfig &config); void full_simplify(IRNode *root, const CompileConfig &config, const FullSimplifyPass::Args &args); diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp index 107d2600332cf..864026314da59 100644 --- a/taichi/transforms/cache_loop_invariant_global_vars.cpp +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -6,12 +6,10 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { public: using LoopInvariantDetector::visit; - enum class CacheStatus { - ReadOnly = 0, - ReadWrite = 1 - }; + enum class CacheStatus { ReadOnly = 0, ReadWrite = 1 }; - std::unordered_map> cached_allocas; + std::unordered_map> + cached_allocas; DelayedIRModifier modifier; @@ -21,25 +19,31 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { void add_writeback(AllocaStmt *alloca_stmt, Stmt *stmt, Stmt *parent_stmt) { auto final_value = std::make_unique(alloca_stmt); - auto global_store = std::make_unique(stmt, final_value.get()); + auto global_store = + std::make_unique(stmt, final_value.get()); modifier.insert_after(parent_stmt, std::move(global_store)); modifier.insert_after(parent_stmt, std::move(final_value)); } - AllocaStmt *cache_global_to_local(Stmt *stmt, Stmt *parent_stmt, CacheStatus status) { + AllocaStmt *cache_global_to_local(Stmt *stmt, + Stmt *parent_stmt, + CacheStatus status) { if (auto &[cached, cached_status] = cached_allocas[stmt]; cached) { - if (cached_status == CacheStatus::ReadOnly && status == CacheStatus::ReadWrite) { + if (cached_status == CacheStatus::ReadOnly && + status == CacheStatus::ReadWrite) { add_writeback(cached, stmt, parent_stmt); cached_status = CacheStatus::ReadWrite; } return cached; } - auto alloca_unique = std::make_unique(stmt->ret_type.ptr_removed()); + auto alloca_unique = + std::make_unique(stmt->ret_type.ptr_removed()); auto alloca_stmt = alloca_unique.get(); cached_allocas[stmt] = {alloca_stmt, status}; auto new_global_load = std::make_unique(stmt); - auto local_store = std::make_unique(alloca_stmt, new_global_load.get()); + auto local_store = + std::make_unique(alloca_stmt, new_global_load.get()); modifier.insert_before(parent_stmt, std::move(new_global_load)); modifier.insert_before(parent_stmt, std::move(alloca_unique)); modifier.insert_before(parent_stmt, std::move(local_store)); @@ -52,7 +56,8 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { void visit(GlobalLoadStmt *stmt) override { if (is_loop_invariant(stmt->src, stmt->parent)) { - auto alloca_stmt = cache_global_to_local(stmt->src, stmt->parent->parent_stmt, CacheStatus::ReadOnly); + auto alloca_stmt = cache_global_to_local( + stmt->src, stmt->parent->parent_stmt, CacheStatus::ReadOnly); auto local_load = std::make_unique(alloca_stmt); stmt->replace_usages_with(local_load.get()); modifier.insert_before(stmt, std::move(local_load)); @@ -62,8 +67,10 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { void visit(GlobalStoreStmt *stmt) override { if (is_loop_invariant(stmt->dest, stmt->parent)) { - auto alloca_stmt = cache_global_to_local(stmt->dest, stmt->parent->parent_stmt, CacheStatus::ReadWrite); - auto local_store = std::make_unique(alloca_stmt, stmt->val); + auto alloca_stmt = cache_global_to_local( + stmt->dest, stmt->parent->parent_stmt, CacheStatus::ReadWrite); + auto local_store = + std::make_unique(alloca_stmt, stmt->val); stmt->replace_usages_with(local_store.get()); modifier.insert_before(stmt, std::move(local_store)); modifier.erase(stmt); @@ -87,7 +94,8 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { }; namespace irpass { -bool cache_loop_invariant_global_vars(IRNode *root, const CompileConfig &config) { +bool cache_loop_invariant_global_vars(IRNode *root, + const CompileConfig &config) { TI_AUTO_PROF; return CacheLoopInvariantGlobalVars::run(root, config); } diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index f540ddb39622f..48f718e612b3f 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -246,8 +246,8 @@ void offload_to_executable(IRNode *ir, irpass::analysis::verify(ir); if (lower_global_access) { - irpass::full_simplify( - ir, config, {false, /*autodiff_enabled*/ false, kernel->program}); + irpass::full_simplify(ir, config, + {false, /*autodiff_enabled*/ false, kernel->program}); print("Simplified before lower access"); irpass::lower_access(ir, config, {kernel->no_activate, true}); print("Access lowered"); diff --git a/taichi/transforms/loop_invariant_detector.h b/taichi/transforms/loop_invariant_detector.h index b4d2a3568c96e..5898fa28d7469 100644 --- a/taichi/transforms/loop_invariant_detector.h +++ b/taichi/transforms/loop_invariant_detector.h @@ -16,8 +16,7 @@ class LoopInvariantDetector : public BasicStmtVisitor { const CompileConfig &config; - explicit LoopInvariantDetector(const CompileConfig &config) - : config(config) { + explicit LoopInvariantDetector(const CompileConfig &config) : config(config) { allow_undefined_visitor = true; } @@ -49,9 +48,9 @@ class LoopInvariantDetector : public BasicStmtVisitor { operand_parent = operand_parent->parent->parent_stmt; if (!operand_parent) break; - // If the one of the current_scope of the operand is the top loop scope - // Then it will not be visible if we move it outside the top loop - // scope + // If the one of the current_scope of the operand is the top loop + // scope Then it will not be visible if we move it outside the top + // loop scope if (operand_parent == loop_blocks.top()->parent_stmt) { is_invariant = false; break; @@ -118,7 +117,6 @@ class LoopInvariantDetector : public BasicStmtVisitor { if (stmt->tls_epilogue) stmt->tls_epilogue->accept(this); } - }; TLANG_NAMESPACE_END From 37cbc1965ca70308300c2085ad10d0be7ca44767 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Fri, 16 Sep 2022 17:09:47 +0800 Subject: [PATCH 03/16] fix --- .../cache_loop_invariant_global_vars.cpp | 63 ++++++++++++------- taichi/transforms/compile_to_offloads.cpp | 2 + .../transforms/loop_invariant_code_motion.cpp | 6 +- taichi/transforms/loop_invariant_detector.h | 6 +- taichi/transforms/simplify.cpp | 2 - 5 files changed, 50 insertions(+), 29 deletions(-) diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp index 864026314da59..211e88af9be8d 100644 --- a/taichi/transforms/cache_loop_invariant_global_vars.cpp +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -6,10 +6,10 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { public: using LoopInvariantDetector::visit; - enum class CacheStatus { ReadOnly = 0, ReadWrite = 1 }; + enum class CacheStatus { Read = 1, Write = 2, ReadWrite = 3 }; - std::unordered_map> - cached_allocas; + typedef std::unordered_map> CacheMap; + std::stack cached_maps; DelayedIRModifier modifier; @@ -17,21 +17,39 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { : LoopInvariantDetector(config) { } - void add_writeback(AllocaStmt *alloca_stmt, Stmt *stmt, Stmt *parent_stmt) { + void visit_loop(Block *body) override { + cached_maps.emplace(); + LoopInvariantDetector::visit_loop(body); + cached_maps.pop(); + } + + void add_writeback(AllocaStmt *alloca_stmt, Stmt *global_var) { auto final_value = std::make_unique(alloca_stmt); auto global_store = - std::make_unique(stmt, final_value.get()); - modifier.insert_after(parent_stmt, std::move(global_store)); - modifier.insert_after(parent_stmt, std::move(final_value)); + std::make_unique(global_var, final_value.get()); + modifier.insert_after(current_loop_stmt(), std::move(global_store)); + modifier.insert_after(current_loop_stmt(), std::move(final_value)); + } + + void set_init_value(AllocaStmt *alloca_stmt, Stmt *global_var) { + auto new_global_load = std::make_unique(global_var); + auto local_store = + std::make_unique(alloca_stmt, new_global_load.get()); + modifier.insert_before(current_loop_stmt(), std::move(new_global_load)); + modifier.insert_before(current_loop_stmt(), std::move(local_store)); } + /* + * + */ AllocaStmt *cache_global_to_local(Stmt *stmt, - Stmt *parent_stmt, CacheStatus status) { - if (auto &[cached, cached_status] = cached_allocas[stmt]; cached) { - if (cached_status == CacheStatus::ReadOnly && - status == CacheStatus::ReadWrite) { - add_writeback(cached, stmt, parent_stmt); + if (auto &[cached, cached_status] = cached_maps.top()[stmt]; cached) { + // The global variable has already been cached. + if (cached_status == CacheStatus::Read && + status == CacheStatus::Write) { + // If the + add_writeback(cached, stmt); cached_status = CacheStatus::ReadWrite; } return cached; @@ -40,16 +58,15 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { auto alloca_unique = std::make_unique(stmt->ret_type.ptr_removed()); auto alloca_stmt = alloca_unique.get(); - cached_allocas[stmt] = {alloca_stmt, status}; - auto new_global_load = std::make_unique(stmt); - auto local_store = - std::make_unique(alloca_stmt, new_global_load.get()); - modifier.insert_before(parent_stmt, std::move(new_global_load)); - modifier.insert_before(parent_stmt, std::move(alloca_unique)); - modifier.insert_before(parent_stmt, std::move(local_store)); + modifier.insert_before(loop_blocks.top()->parent_stmt, std::move(alloca_unique)); + cached_maps.top()[stmt] = {alloca_stmt, status}; + + if (status == CacheStatus::Read) { + set_init_value(alloca_stmt, stmt); + } - if (status == CacheStatus::ReadWrite) { - add_writeback(alloca_stmt, stmt, parent_stmt); + if (status == CacheStatus::Write) { + add_writeback(alloca_stmt, stmt); } return alloca_stmt; } @@ -57,7 +74,7 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { void visit(GlobalLoadStmt *stmt) override { if (is_loop_invariant(stmt->src, stmt->parent)) { auto alloca_stmt = cache_global_to_local( - stmt->src, stmt->parent->parent_stmt, CacheStatus::ReadOnly); + stmt->src, CacheStatus::Read); auto local_load = std::make_unique(alloca_stmt); stmt->replace_usages_with(local_load.get()); modifier.insert_before(stmt, std::move(local_load)); @@ -68,7 +85,7 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { void visit(GlobalStoreStmt *stmt) override { if (is_loop_invariant(stmt->dest, stmt->parent)) { auto alloca_stmt = cache_global_to_local( - stmt->dest, stmt->parent->parent_stmt, CacheStatus::ReadWrite); + stmt->dest, CacheStatus::Write); auto local_store = std::make_unique(alloca_stmt, stmt->val); stmt->replace_usages_with(local_store.get()); diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 48f718e612b3f..03af3ab5754dd 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -246,6 +246,8 @@ void offload_to_executable(IRNode *ir, irpass::analysis::verify(ir); if (lower_global_access) { + irpass::cache_loop_invariant_global_vars(ir, config); + print("Cache loop-invariant global vars"); irpass::full_simplify(ir, config, {false, /*autodiff_enabled*/ false, kernel->program}); print("Simplified before lower access"); diff --git a/taichi/transforms/loop_invariant_code_motion.cpp b/taichi/transforms/loop_invariant_code_motion.cpp index 7529f28a4f87f..f7105f76c3ef9 100644 --- a/taichi/transforms/loop_invariant_code_motion.cpp +++ b/taichi/transforms/loop_invariant_code_motion.cpp @@ -17,7 +17,7 @@ class LoopInvariantCodeMotion : public LoopInvariantDetector { auto replacement = stmt->clone(); stmt->replace_usages_with(replacement.get()); - modifier.insert_before(stmt->parent->parent_stmt, std::move(replacement)); + modifier.insert_before(current_loop_stmt(), std::move(replacement)); modifier.erase(stmt); } } @@ -27,7 +27,7 @@ class LoopInvariantCodeMotion : public LoopInvariantDetector { auto replacement = stmt->clone(); stmt->replace_usages_with(replacement.get()); - modifier.insert_before(stmt->parent->parent_stmt, std::move(replacement)); + modifier.insert_before(current_loop_stmt(), std::move(replacement)); modifier.erase(stmt); } } @@ -37,7 +37,7 @@ class LoopInvariantCodeMotion : public LoopInvariantDetector { auto replacement = stmt->clone(); stmt->replace_usages_with(replacement.get()); - modifier.insert_before(stmt->parent->parent_stmt, std::move(replacement)); + modifier.insert_before(current_loop_stmt(), std::move(replacement)); modifier.erase(stmt); } } diff --git a/taichi/transforms/loop_invariant_detector.h b/taichi/transforms/loop_invariant_detector.h index 5898fa28d7469..09ed962a1456b 100644 --- a/taichi/transforms/loop_invariant_detector.h +++ b/taichi/transforms/loop_invariant_detector.h @@ -64,12 +64,16 @@ class LoopInvariantDetector : public BasicStmtVisitor { return is_invariant; } + Stmt *current_loop_stmt() { + return loop_blocks.top()->parent_stmt; + } + void visit(Block *stmt_list) override { for (auto &stmt : stmt_list->statements) stmt->accept(this); } - void visit_loop(Block *body) { + virtual void visit_loop(Block *body) { loop_blocks.push(body); body->accept(this); diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index a727a5162a9be..6c18c88a9194d 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -632,8 +632,6 @@ void full_simplify(IRNode *root, modified = true; if (config.opt_level > 0 && whole_kernel_cse(root)) modified = true; - if (cache_loop_invariant_global_vars(root, config)) - modified = true; // Don't do this time-consuming optimization pass again if the IR is // not modified. if (config.opt_level > 0 && (first_iteration || modified) && From 43f659579a7f9d0daa328951a576232d50d77667 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Fri, 16 Sep 2022 17:49:53 +0800 Subject: [PATCH 04/16] fix --- .../cache_loop_invariant_global_vars.cpp | 4 +- taichi/transforms/loop_invariant_detector.h | 63 ++++++++++--------- tests/python/test_tuple_assign.py | 2 +- 3 files changed, 35 insertions(+), 34 deletions(-) diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp index 211e88af9be8d..cfce2f498fec3 100644 --- a/taichi/transforms/cache_loop_invariant_global_vars.cpp +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -72,7 +72,7 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { } void visit(GlobalLoadStmt *stmt) override { - if (is_loop_invariant(stmt->src, stmt->parent)) { + if (is_operand_loop_invariant(stmt->src, stmt->parent)) { auto alloca_stmt = cache_global_to_local( stmt->src, CacheStatus::Read); auto local_load = std::make_unique(alloca_stmt); @@ -83,7 +83,7 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { } void visit(GlobalStoreStmt *stmt) override { - if (is_loop_invariant(stmt->dest, stmt->parent)) { + if (is_operand_loop_invariant(stmt->dest, stmt->parent)) { auto alloca_stmt = cache_global_to_local( stmt->dest, CacheStatus::Write); auto local_store = diff --git a/taichi/transforms/loop_invariant_detector.h b/taichi/transforms/loop_invariant_detector.h index 09ed962a1456b..d23ab87ee4fb2 100644 --- a/taichi/transforms/loop_invariant_detector.h +++ b/taichi/transforms/loop_invariant_detector.h @@ -20,6 +20,37 @@ class LoopInvariantDetector : public BasicStmtVisitor { allow_undefined_visitor = true; } + bool is_operand_loop_invariant(Stmt *operand, Block *current_scope) { + if (operand->parent == current_scope) { + // This statement has an operand that is in the current scope, + // so it can not be moved out of the scope. + return false; + } + if (current_scope != loop_blocks.top()) { + // If we enable moving code from a nested if block, we need to check + // visibility. Example: + // for i in range(10): + // a = x[0] + // if b: + // c = a + 1 + // Since we are moving statements outside the cloest for scope, + // We need to check the scope of the operand + Stmt *operand_parent = operand; + while (operand_parent->parent) { + operand_parent = operand_parent->parent->parent_stmt; + if (!operand_parent) + break; + // If the one of the current_scope of the operand is the top loop + // scope Then it will not be visible if we move it outside the top + // loop scope + if (operand_parent == loop_blocks.top()->parent_stmt) { + return false; + } + } + } + return true; + } + bool is_loop_invariant(Stmt *stmt, Block *current_scope) { if (loop_blocks.size() <= 1 || (!config.move_loop_invariant_outside_if && current_scope != loop_blocks.top())) @@ -28,37 +59,7 @@ class LoopInvariantDetector : public BasicStmtVisitor { bool is_invariant = true; for (Stmt *operand : stmt->get_operands()) { - if (operand->parent == current_scope) { - // This statement has an operand that is in the current scope, - // so it can not be moved out of the scope. - is_invariant = false; - break; - } - if (current_scope != loop_blocks.top()) { - // If we enable moving code from a nested if block, we need to check - // visibility. Example: - // for i in range(10): - // a = x[0] - // if b: - // c = a + 1 - // Since we are moving statements outside the cloest for scope, - // We need to check the scope of the operand - Stmt *operand_parent = operand; - while (operand_parent->parent) { - operand_parent = operand_parent->parent->parent_stmt; - if (!operand_parent) - break; - // If the one of the current_scope of the operand is the top loop - // scope Then it will not be visible if we move it outside the top - // loop scope - if (operand_parent == loop_blocks.top()->parent_stmt) { - is_invariant = false; - break; - } - } - if (!is_invariant) - break; - } + is_invariant &= is_operand_loop_invariant(operand, current_scope); } return is_invariant; diff --git a/tests/python/test_tuple_assign.py b/tests/python/test_tuple_assign.py index 03dc05bec7d18..caf7939b4c278 100644 --- a/tests/python/test_tuple_assign.py +++ b/tests/python/test_tuple_assign.py @@ -5,7 +5,7 @@ from tests import test_utils -@test_utils.test() +@test_utils.test(print_ir=True) def test_fibonacci(): @ti.kernel def ti_fibonacci(n: ti.i32) -> ti.i32: From b14ba578c1dbc2b16fce19390d01138c6b07b159 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Fri, 16 Sep 2022 17:55:45 +0800 Subject: [PATCH 05/16] fix --- .../transforms/cache_loop_invariant_global_vars.cpp | 4 ++-- taichi/transforms/loop_invariant_code_motion.cpp | 10 ++++++++++ taichi/transforms/loop_invariant_detector.h | 11 +++++++++-- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp index cfce2f498fec3..53e20f727bf0d 100644 --- a/taichi/transforms/cache_loop_invariant_global_vars.cpp +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -72,7 +72,7 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { } void visit(GlobalLoadStmt *stmt) override { - if (is_operand_loop_invariant(stmt->src, stmt->parent)) { + if (is_operand_loop_invariant_impl(stmt->src, stmt->parent)) { auto alloca_stmt = cache_global_to_local( stmt->src, CacheStatus::Read); auto local_load = std::make_unique(alloca_stmt); @@ -83,7 +83,7 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { } void visit(GlobalStoreStmt *stmt) override { - if (is_operand_loop_invariant(stmt->dest, stmt->parent)) { + if (is_operand_loop_invariant_impl(stmt->dest, stmt->parent)) { auto alloca_stmt = cache_global_to_local( stmt->dest, CacheStatus::Write); auto local_store = diff --git a/taichi/transforms/loop_invariant_code_motion.cpp b/taichi/transforms/loop_invariant_code_motion.cpp index f7105f76c3ef9..3d979a3ec7491 100644 --- a/taichi/transforms/loop_invariant_code_motion.cpp +++ b/taichi/transforms/loop_invariant_code_motion.cpp @@ -42,6 +42,16 @@ class LoopInvariantCodeMotion : public LoopInvariantDetector { } } + void visit(GlobalTemporaryStmt *stmt) override { + if (is_loop_invariant(stmt, stmt->parent)) { + auto replacement = stmt->clone(); + stmt->replace_usages_with(replacement.get()); + + modifier.insert_before(current_loop_stmt(), std::move(replacement)); + modifier.erase(stmt); + } + } + static bool run(IRNode *node, const CompileConfig &config) { bool modified = false; diff --git a/taichi/transforms/loop_invariant_detector.h b/taichi/transforms/loop_invariant_detector.h index d23ab87ee4fb2..efa98e8fee7dd 100644 --- a/taichi/transforms/loop_invariant_detector.h +++ b/taichi/transforms/loop_invariant_detector.h @@ -20,7 +20,7 @@ class LoopInvariantDetector : public BasicStmtVisitor { allow_undefined_visitor = true; } - bool is_operand_loop_invariant(Stmt *operand, Block *current_scope) { + bool is_operand_loop_invariant_impl(Stmt *operand, Block *current_scope) { if (operand->parent == current_scope) { // This statement has an operand that is in the current scope, // so it can not be moved out of the scope. @@ -51,6 +51,13 @@ class LoopInvariantDetector : public BasicStmtVisitor { return true; } + bool is_operand_loop_invariant(Stmt *operand, Block *current_scope) { + if (loop_blocks.size() <= 1 || (!config.move_loop_invariant_outside_if && + current_scope != loop_blocks.top())) + return false; + return is_operand_loop_invariant_impl(operand, current_scope); + } + bool is_loop_invariant(Stmt *stmt, Block *current_scope) { if (loop_blocks.size() <= 1 || (!config.move_loop_invariant_outside_if && current_scope != loop_blocks.top())) @@ -59,7 +66,7 @@ class LoopInvariantDetector : public BasicStmtVisitor { bool is_invariant = true; for (Stmt *operand : stmt->get_operands()) { - is_invariant &= is_operand_loop_invariant(operand, current_scope); + is_invariant &= is_operand_loop_invariant_impl(operand, current_scope); } return is_invariant; From 6bcfc19dfa27c5edd9365b413d42d944cd7ab7a9 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Mon, 19 Sep 2022 14:21:36 +0800 Subject: [PATCH 06/16] fix --- .../cache_loop_invariant_global_vars.cpp | 138 +++++++++++++----- taichi/transforms/compile_to_offloads.cpp | 5 +- taichi/transforms/loop_invariant_detector.h | 4 +- 3 files changed, 103 insertions(+), 44 deletions(-) diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp index 53e20f727bf0d..6f68e2d8600ec 100644 --- a/taichi/transforms/cache_loop_invariant_global_vars.cpp +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -1,4 +1,5 @@ #include "taichi/transforms/loop_invariant_detector.h" +#include "taichi/ir/analysis.h" TLANG_NAMESPACE_BEGIN @@ -6,20 +7,104 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { public: using LoopInvariantDetector::visit; - enum class CacheStatus { Read = 1, Write = 2, ReadWrite = 3 }; + enum class CacheStatus { None = 0, Read = 1, Write = 2, ReadWrite = 3, HasAtomic = 4 }; - typedef std::unordered_map> CacheMap; + typedef std::unordered_map>> CacheMap; std::stack cached_maps; DelayedIRModifier modifier; + std::unordered_map loop_unique_ptr_; + OffloadedStmt *current_offloaded; explicit CacheLoopInvariantGlobalVars(const CompileConfig &config) : LoopInvariantDetector(config) { } + void visit(OffloadedStmt *stmt) override { + if (stmt->task_type == OffloadedTaskType::range_for || + stmt->task_type == OffloadedTaskType::mesh_for || + stmt->task_type == OffloadedTaskType::struct_for) { + auto uniquely_accessed_pointers = + irpass::analysis::gather_uniquely_accessed_pointers(stmt); + loop_unique_ptr_ = std::move(uniquely_accessed_pointers.first); + } + current_offloaded = stmt; + // We don't need to visit TLS/BLS prologues/epilogues. + if (stmt->body) { + if (stmt->task_type == OffloadedStmt::TaskType::range_for || + stmt->task_type == OffloadedTaskType::mesh_for || + stmt->task_type == OffloadedStmt::TaskType::struct_for) + visit_loop(stmt->body.get()); + else + stmt->body->accept(this); + } + current_offloaded = nullptr; + } + + bool is_offload_unique(Stmt *stmt) { + if (current_offloaded->task_type == OffloadedTaskType::serial) { + return true; + } + if (auto global_ptr = stmt->cast()) { + auto snode = global_ptr->snode; + if (loop_unique_ptr_[snode] == nullptr || + loop_unique_ptr_[snode]->indices.empty()) { + // not uniquely accessed + return false; + } + if (current_offloaded->mem_access_opt.has_flag( + snode, SNodeAccessFlag::block_local) || + current_offloaded->mem_access_opt.has_flag( + snode, SNodeAccessFlag::mesh_local)) { + // BLS does not support write access yet so we keep atomic_adds. + return false; + } + return true; + } + return false; + } + void visit_loop(Block *body) override { cached_maps.emplace(); - LoopInvariantDetector::visit_loop(body); + + loop_blocks.push(body); + + body->accept(this); + + for (auto &[dest, status_vec] : cached_maps.top()) { + auto &[status, vec] = status_vec; + if (status == CacheStatus::HasAtomic) { + continue; + } + auto alloca_unique = + std::make_unique(dest->ret_type.ptr_removed()); + auto alloca_stmt = alloca_unique.get(); + modifier.insert_before(body->parent_stmt, std::move(alloca_unique)); + if (int(status) & int(CacheStatus::Read)) { + set_init_value(alloca_stmt, dest); + } + if (int(status) & int(CacheStatus::Write)) { + add_writeback(alloca_stmt, dest); + } + for (auto *stmt : vec) { + if (stmt->is()) { + auto local_load = std::make_unique(alloca_stmt); + stmt->replace_usages_with(local_load.get()); + modifier.insert_before(stmt, std::move(local_load)); + modifier.erase(stmt); + } + else if (auto *global_store = stmt->cast()) { + auto local_store = + std::make_unique(alloca_stmt, global_store->val); + stmt->replace_usages_with(local_store.get()); + modifier.insert_before(stmt, std::move(local_store)); + modifier.erase(stmt); + } else { + TI_UNREACHABLE + } + } + } + loop_blocks.pop(); cached_maps.pop(); } @@ -39,58 +124,32 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { modifier.insert_before(current_loop_stmt(), std::move(local_store)); } - /* - * - */ - AllocaStmt *cache_global_to_local(Stmt *stmt, + void cache_global_to_local(Stmt *stmt, Stmt *dest, CacheStatus status) { - if (auto &[cached, cached_status] = cached_maps.top()[stmt]; cached) { + if (auto &[cached_status, vec] = cached_maps.top()[dest]; cached_status != CacheStatus::None) { // The global variable has already been cached. if (cached_status == CacheStatus::Read && status == CacheStatus::Write) { - // If the - add_writeback(cached, stmt); cached_status = CacheStatus::ReadWrite; } - return cached; + vec.push_back(stmt); + return; } - - auto alloca_unique = - std::make_unique(stmt->ret_type.ptr_removed()); - auto alloca_stmt = alloca_unique.get(); - modifier.insert_before(loop_blocks.top()->parent_stmt, std::move(alloca_unique)); - cached_maps.top()[stmt] = {alloca_stmt, status}; - - if (status == CacheStatus::Read) { - set_init_value(alloca_stmt, stmt); - } - - if (status == CacheStatus::Write) { - add_writeback(alloca_stmt, stmt); - } - return alloca_stmt; + cached_maps.top()[dest] = {status, {stmt}}; } void visit(GlobalLoadStmt *stmt) override { - if (is_operand_loop_invariant_impl(stmt->src, stmt->parent)) { - auto alloca_stmt = cache_global_to_local( + if (is_offload_unique(stmt->src) && is_operand_loop_invariant_impl(stmt->src, stmt->parent)) { + cache_global_to_local(stmt, stmt->src, CacheStatus::Read); - auto local_load = std::make_unique(alloca_stmt); - stmt->replace_usages_with(local_load.get()); - modifier.insert_before(stmt, std::move(local_load)); - modifier.erase(stmt); } } void visit(GlobalStoreStmt *stmt) override { - if (is_operand_loop_invariant_impl(stmt->dest, stmt->parent)) { - auto alloca_stmt = cache_global_to_local( + if (is_offload_unique(stmt->dest) && is_operand_loop_invariant_impl(stmt->dest, stmt->parent)) { + cache_global_to_local(stmt, stmt->dest, CacheStatus::Write); - auto local_store = - std::make_unique(alloca_stmt, stmt->val); - stmt->replace_usages_with(local_store.get()); - modifier.insert_before(stmt, std::move(local_store)); - modifier.erase(stmt); + } } @@ -99,6 +158,7 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { while (true) { CacheLoopInvariantGlobalVars eliminator(config); + irpass::print(node); node->accept(&eliminator); if (eliminator.modifier.modify_ir()) modified = true; diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 03af3ab5754dd..31591f5311f58 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -186,7 +186,8 @@ void offload_to_executable(IRNode *ir, irpass::demote_atomics(ir, config); print("Atomics demoted I"); irpass::analysis::verify(ir); - + irpass::cache_loop_invariant_global_vars(ir, config); + print("Cache loop-invariant global vars"); if (config.demote_dense_struct_fors) { irpass::demote_dense_struct_fors(ir, config.packed); irpass::type_check(ir, config); @@ -246,8 +247,6 @@ void offload_to_executable(IRNode *ir, irpass::analysis::verify(ir); if (lower_global_access) { - irpass::cache_loop_invariant_global_vars(ir, config); - print("Cache loop-invariant global vars"); irpass::full_simplify(ir, config, {false, /*autodiff_enabled*/ false, kernel->program}); print("Simplified before lower access"); diff --git a/taichi/transforms/loop_invariant_detector.h b/taichi/transforms/loop_invariant_detector.h index efa98e8fee7dd..0591e39df6ede 100644 --- a/taichi/transforms/loop_invariant_detector.h +++ b/taichi/transforms/loop_invariant_detector.h @@ -52,8 +52,7 @@ class LoopInvariantDetector : public BasicStmtVisitor { } bool is_operand_loop_invariant(Stmt *operand, Block *current_scope) { - if (loop_blocks.size() <= 1 || (!config.move_loop_invariant_outside_if && - current_scope != loop_blocks.top())) + if (loop_blocks.size() <= 1) return false; return is_operand_loop_invariant_impl(operand, current_scope); } @@ -117,6 +116,7 @@ class LoopInvariantDetector : public BasicStmtVisitor { if (stmt->body) { if (stmt->task_type == OffloadedStmt::TaskType::range_for || + stmt->task_type == OffloadedTaskType::mesh_for || stmt->task_type == OffloadedStmt::TaskType::struct_for) visit_loop(stmt->body.get()); else From 16d6d010c2bb6ef1b677db9ad8ee193a05391707 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Mon, 19 Sep 2022 15:03:59 +0800 Subject: [PATCH 07/16] fix --- .../cache_loop_invariant_global_vars.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp index 6f68e2d8600ec..e8c647a9500ed 100644 --- a/taichi/transforms/cache_loop_invariant_global_vars.cpp +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -14,6 +14,8 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { DelayedIRModifier modifier; std::unordered_map loop_unique_ptr_; + std::unordered_map loop_unique_arr_ptr_; + OffloadedStmt *current_offloaded; explicit CacheLoopInvariantGlobalVars(const CompileConfig &config) @@ -27,6 +29,7 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { auto uniquely_accessed_pointers = irpass::analysis::gather_uniquely_accessed_pointers(stmt); loop_unique_ptr_ = std::move(uniquely_accessed_pointers.first); + loop_unique_arr_ptr_ = std::move(uniquely_accessed_pointers.second); } current_offloaded = stmt; // We don't need to visit TLS/BLS prologues/epilogues. @@ -60,6 +63,19 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { return false; } return true; + } else if (stmt->is()) { + ExternalPtrStmt *dest_ptr = stmt->as(); + if (dest_ptr->indices.empty()) { + return false; + } + ArgLoadStmt *arg_load_stmt = dest_ptr->base_ptr->as(); + int arg_id = arg_load_stmt->arg_id; + if (loop_unique_arr_ptr_[arg_id] == nullptr) { + // Not loop unique + return false; + } + return true; + // TODO: Is BLS / Mem Access Opt a thing for any_arr? } return false; } From d12dd57ccd8af9be83272824d2dc383eee9f6338 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Mon, 19 Sep 2022 15:28:01 +0800 Subject: [PATCH 08/16] stash --- .../cache_loop_invariant_global_vars.cpp | 1 - .../transforms/loop_invariant_code_motion.cpp | 20 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp index e8c647a9500ed..808d8f3a5cbe9 100644 --- a/taichi/transforms/cache_loop_invariant_global_vars.cpp +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -174,7 +174,6 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { while (true) { CacheLoopInvariantGlobalVars eliminator(config); - irpass::print(node); node->accept(&eliminator); if (eliminator.modifier.modify_ir()) modified = true; diff --git a/taichi/transforms/loop_invariant_code_motion.cpp b/taichi/transforms/loop_invariant_code_motion.cpp index 3d979a3ec7491..dc1ad5c0c023c 100644 --- a/taichi/transforms/loop_invariant_code_motion.cpp +++ b/taichi/transforms/loop_invariant_code_motion.cpp @@ -52,6 +52,26 @@ class LoopInvariantCodeMotion : public LoopInvariantDetector { } } + void visit(ExternalPtrStmt *stmt) override { + if (is_loop_invariant(stmt, stmt->parent)) { + auto replacement = stmt->clone(); + stmt->replace_usages_with(replacement.get()); + + modifier.insert_before(current_loop_stmt(), std::move(replacement)); + modifier.erase(stmt); + } + } + + void visit(ArgLoadStmt *stmt) override { + if (is_loop_invariant(stmt, stmt->parent)) { + auto replacement = stmt->clone(); + stmt->replace_usages_with(replacement.get()); + + modifier.insert_before(current_loop_stmt(), std::move(replacement)); + modifier.erase(stmt); + } + } + static bool run(IRNode *node, const CompileConfig &config) { bool modified = false; From 1f09e3547874c46b687006d474341a7fda2567c9 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Mon, 19 Sep 2022 16:12:42 +0800 Subject: [PATCH 09/16] fix --- taichi/transforms/cache_loop_invariant_global_vars.cpp | 4 ++-- taichi/transforms/loop_invariant_code_motion.cpp | 10 ---------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp index 808d8f3a5cbe9..f9a7154d10a85 100644 --- a/taichi/transforms/cache_loop_invariant_global_vars.cpp +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -155,14 +155,14 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { } void visit(GlobalLoadStmt *stmt) override { - if (is_offload_unique(stmt->src) && is_operand_loop_invariant_impl(stmt->src, stmt->parent)) { + if (is_offload_unique(stmt->src) && is_operand_loop_invariant(stmt->src, stmt->parent)) { cache_global_to_local(stmt, stmt->src, CacheStatus::Read); } } void visit(GlobalStoreStmt *stmt) override { - if (is_offload_unique(stmt->dest) && is_operand_loop_invariant_impl(stmt->dest, stmt->parent)) { + if (is_offload_unique(stmt->dest) && is_operand_loop_invariant(stmt->dest, stmt->parent)) { cache_global_to_local(stmt, stmt->dest, CacheStatus::Write); diff --git a/taichi/transforms/loop_invariant_code_motion.cpp b/taichi/transforms/loop_invariant_code_motion.cpp index dc1ad5c0c023c..0f8ec5fa7b8ab 100644 --- a/taichi/transforms/loop_invariant_code_motion.cpp +++ b/taichi/transforms/loop_invariant_code_motion.cpp @@ -42,16 +42,6 @@ class LoopInvariantCodeMotion : public LoopInvariantDetector { } } - void visit(GlobalTemporaryStmt *stmt) override { - if (is_loop_invariant(stmt, stmt->parent)) { - auto replacement = stmt->clone(); - stmt->replace_usages_with(replacement.get()); - - modifier.insert_before(current_loop_stmt(), std::move(replacement)); - modifier.erase(stmt); - } - } - void visit(ExternalPtrStmt *stmt) override { if (is_loop_invariant(stmt, stmt->parent)) { auto replacement = stmt->clone(); From 9d5984e9d77011bac54780f9e308a7936195b2d6 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Mon, 19 Sep 2022 16:13:33 +0800 Subject: [PATCH 10/16] fix --- tests/python/test_tuple_assign.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_tuple_assign.py b/tests/python/test_tuple_assign.py index caf7939b4c278..03dc05bec7d18 100644 --- a/tests/python/test_tuple_assign.py +++ b/tests/python/test_tuple_assign.py @@ -5,7 +5,7 @@ from tests import test_utils -@test_utils.test(print_ir=True) +@test_utils.test() def test_fibonacci(): @ti.kernel def ti_fibonacci(n: ti.i32) -> ti.i32: From 1d646b4400245aaadf90dad108d9f204cbf7e8b3 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Tue, 20 Sep 2022 15:03:21 +0800 Subject: [PATCH 11/16] add compile config --- taichi/program/compile_config.h | 1 + taichi/python/export_lang.cpp | 1 + taichi/transforms/compile_to_offloads.cpp | 7 +++++-- taichi/transforms/loop_invariant_code_motion.cpp | 6 +++--- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/taichi/program/compile_config.h b/taichi/program/compile_config.h index 704e25a480df1..f2903114ec2ba 100644 --- a/taichi/program/compile_config.h +++ b/taichi/program/compile_config.h @@ -28,6 +28,7 @@ struct CompileConfig { bool lower_access; bool simplify_after_lower_access; bool move_loop_invariant_outside_if; + bool cache_loop_invariant_global_vars{false}; bool demote_dense_struct_fors; bool advanced_optimization; bool constant_folding; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 3795ba380aebb..b5f7266f643f4 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -166,6 +166,7 @@ void export_lang(py::module &m) { .def_readwrite("lower_access", &CompileConfig::lower_access) .def_readwrite("move_loop_invariant_outside_if", &CompileConfig::move_loop_invariant_outside_if) + .def_readwrite("cache_loop_invariant_global_vars", &CompileConfig::cache_loop_invariant_global_vars) .def_readwrite("default_cpu_block_dim", &CompileConfig::default_cpu_block_dim) .def_readwrite("cpu_block_dim_adaptive", diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 31591f5311f58..58d00e0c2332f 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -186,8 +186,11 @@ void offload_to_executable(IRNode *ir, irpass::demote_atomics(ir, config); print("Atomics demoted I"); irpass::analysis::verify(ir); - irpass::cache_loop_invariant_global_vars(ir, config); - print("Cache loop-invariant global vars"); + if (config.cache_loop_invariant_global_vars) { + irpass::cache_loop_invariant_global_vars(ir, config); + print("Cache loop-invariant global vars"); + } + if (config.demote_dense_struct_fors) { irpass::demote_dense_struct_fors(ir, config.packed); irpass::type_check(ir, config); diff --git a/taichi/transforms/loop_invariant_code_motion.cpp b/taichi/transforms/loop_invariant_code_motion.cpp index 0f8ec5fa7b8ab..3a50421f19d99 100644 --- a/taichi/transforms/loop_invariant_code_motion.cpp +++ b/taichi/transforms/loop_invariant_code_motion.cpp @@ -33,7 +33,7 @@ class LoopInvariantCodeMotion : public LoopInvariantDetector { } void visit(GlobalPtrStmt *stmt) override { - if (is_loop_invariant(stmt, stmt->parent)) { + if (config.cache_loop_invariant_global_vars && is_loop_invariant(stmt, stmt->parent)) { auto replacement = stmt->clone(); stmt->replace_usages_with(replacement.get()); @@ -43,7 +43,7 @@ class LoopInvariantCodeMotion : public LoopInvariantDetector { } void visit(ExternalPtrStmt *stmt) override { - if (is_loop_invariant(stmt, stmt->parent)) { + if (config.cache_loop_invariant_global_vars && is_loop_invariant(stmt, stmt->parent)) { auto replacement = stmt->clone(); stmt->replace_usages_with(replacement.get()); @@ -53,7 +53,7 @@ class LoopInvariantCodeMotion : public LoopInvariantDetector { } void visit(ArgLoadStmt *stmt) override { - if (is_loop_invariant(stmt, stmt->parent)) { + if (config.cache_loop_invariant_global_vars && is_loop_invariant(stmt, stmt->parent)) { auto replacement = stmt->clone(); stmt->replace_usages_with(replacement.get()); From 45b2775fe720a646cca1f8ea23730910ac568f52 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Sep 2022 07:05:18 +0000 Subject: [PATCH 12/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/python/export_lang.cpp | 3 +- .../cache_loop_invariant_global_vars.cpp | 39 +++++++++++-------- .../transforms/loop_invariant_code_motion.cpp | 9 +++-- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index b5f7266f643f4..0c70a7170b1e8 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -166,7 +166,8 @@ void export_lang(py::module &m) { .def_readwrite("lower_access", &CompileConfig::lower_access) .def_readwrite("move_loop_invariant_outside_if", &CompileConfig::move_loop_invariant_outside_if) - .def_readwrite("cache_loop_invariant_global_vars", &CompileConfig::cache_loop_invariant_global_vars) + .def_readwrite("cache_loop_invariant_global_vars", + &CompileConfig::cache_loop_invariant_global_vars) .def_readwrite("default_cpu_block_dim", &CompileConfig::default_cpu_block_dim) .def_readwrite("cpu_block_dim_adaptive", diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp index f9a7154d10a85..5f45ae1aa521b 100644 --- a/taichi/transforms/cache_loop_invariant_global_vars.cpp +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -7,9 +7,17 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { public: using LoopInvariantDetector::visit; - enum class CacheStatus { None = 0, Read = 1, Write = 2, ReadWrite = 3, HasAtomic = 4 }; - - typedef std::unordered_map>> CacheMap; + enum class CacheStatus { + None = 0, + Read = 1, + Write = 2, + ReadWrite = 3, + HasAtomic = 4 + }; + + typedef std::unordered_map>> + CacheMap; std::stack cached_maps; DelayedIRModifier modifier; @@ -108,8 +116,7 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { stmt->replace_usages_with(local_load.get()); modifier.insert_before(stmt, std::move(local_load)); modifier.erase(stmt); - } - else if (auto *global_store = stmt->cast()) { + } else if (auto *global_store = stmt->cast()) { auto local_store = std::make_unique(alloca_stmt, global_store->val); stmt->replace_usages_with(local_store.get()); @@ -140,12 +147,11 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { modifier.insert_before(current_loop_stmt(), std::move(local_store)); } - void cache_global_to_local(Stmt *stmt, Stmt *dest, - CacheStatus status) { - if (auto &[cached_status, vec] = cached_maps.top()[dest]; cached_status != CacheStatus::None) { + void cache_global_to_local(Stmt *stmt, Stmt *dest, CacheStatus status) { + if (auto &[cached_status, vec] = cached_maps.top()[dest]; + cached_status != CacheStatus::None) { // The global variable has already been cached. - if (cached_status == CacheStatus::Read && - status == CacheStatus::Write) { + if (cached_status == CacheStatus::Read && status == CacheStatus::Write) { cached_status = CacheStatus::ReadWrite; } vec.push_back(stmt); @@ -155,17 +161,16 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { } void visit(GlobalLoadStmt *stmt) override { - if (is_offload_unique(stmt->src) && is_operand_loop_invariant(stmt->src, stmt->parent)) { - cache_global_to_local(stmt, - stmt->src, CacheStatus::Read); + if (is_offload_unique(stmt->src) && + is_operand_loop_invariant(stmt->src, stmt->parent)) { + cache_global_to_local(stmt, stmt->src, CacheStatus::Read); } } void visit(GlobalStoreStmt *stmt) override { - if (is_offload_unique(stmt->dest) && is_operand_loop_invariant(stmt->dest, stmt->parent)) { - cache_global_to_local(stmt, - stmt->dest, CacheStatus::Write); - + if (is_offload_unique(stmt->dest) && + is_operand_loop_invariant(stmt->dest, stmt->parent)) { + cache_global_to_local(stmt, stmt->dest, CacheStatus::Write); } } diff --git a/taichi/transforms/loop_invariant_code_motion.cpp b/taichi/transforms/loop_invariant_code_motion.cpp index 3a50421f19d99..8e9fc4c4f9a11 100644 --- a/taichi/transforms/loop_invariant_code_motion.cpp +++ b/taichi/transforms/loop_invariant_code_motion.cpp @@ -33,7 +33,8 @@ class LoopInvariantCodeMotion : public LoopInvariantDetector { } void visit(GlobalPtrStmt *stmt) override { - if (config.cache_loop_invariant_global_vars && is_loop_invariant(stmt, stmt->parent)) { + if (config.cache_loop_invariant_global_vars && + is_loop_invariant(stmt, stmt->parent)) { auto replacement = stmt->clone(); stmt->replace_usages_with(replacement.get()); @@ -43,7 +44,8 @@ class LoopInvariantCodeMotion : public LoopInvariantDetector { } void visit(ExternalPtrStmt *stmt) override { - if (config.cache_loop_invariant_global_vars && is_loop_invariant(stmt, stmt->parent)) { + if (config.cache_loop_invariant_global_vars && + is_loop_invariant(stmt, stmt->parent)) { auto replacement = stmt->clone(); stmt->replace_usages_with(replacement.get()); @@ -53,7 +55,8 @@ class LoopInvariantCodeMotion : public LoopInvariantDetector { } void visit(ArgLoadStmt *stmt) override { - if (config.cache_loop_invariant_global_vars && is_loop_invariant(stmt, stmt->parent)) { + if (config.cache_loop_invariant_global_vars && + is_loop_invariant(stmt, stmt->parent)) { auto replacement = stmt->clone(); stmt->replace_usages_with(replacement.get()); From cb9a4f0b889222e8f433987b37692fcac84c0457 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Thu, 22 Sep 2022 15:38:43 +0800 Subject: [PATCH 13/16] default on --- taichi/program/compile_config.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/program/compile_config.h b/taichi/program/compile_config.h index f2903114ec2ba..8868f7a642700 100644 --- a/taichi/program/compile_config.h +++ b/taichi/program/compile_config.h @@ -28,7 +28,7 @@ struct CompileConfig { bool lower_access; bool simplify_after_lower_access; bool move_loop_invariant_outside_if; - bool cache_loop_invariant_global_vars{false}; + bool cache_loop_invariant_global_vars{true}; bool demote_dense_struct_fors; bool advanced_optimization; bool constant_folding; From 0b94539e6d6996c4babcee0f3ae2e388b7b3dc13 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Fri, 23 Sep 2022 11:01:52 +0800 Subject: [PATCH 14/16] simplify --- taichi/transforms/cache_loop_invariant_global_vars.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp index 5f45ae1aa521b..c895f25d6ffc2 100644 --- a/taichi/transforms/cache_loop_invariant_global_vars.cpp +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -12,7 +12,6 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { Read = 1, Write = 2, ReadWrite = 3, - HasAtomic = 4 }; typedef std::unordered_map(dest->ret_type.ptr_removed()); auto alloca_stmt = alloca_unique.get(); From c01ccef1bfd0d9237121a602e16eed2742481401 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Fri, 23 Sep 2022 11:31:53 +0800 Subject: [PATCH 15/16] simplify --- .../cache_loop_invariant_global_vars.cpp | 77 ++++++++----------- taichi/transforms/loop_invariant_detector.h | 6 +- 2 files changed, 33 insertions(+), 50 deletions(-) diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp index c895f25d6ffc2..18a507ce0e226 100644 --- a/taichi/transforms/cache_loop_invariant_global_vars.cpp +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -1,7 +1,7 @@ #include "taichi/transforms/loop_invariant_detector.h" #include "taichi/ir/analysis.h" -TLANG_NAMESPACE_BEGIN +namespace taichi::lang { class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { public: @@ -14,8 +14,7 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { ReadWrite = 3, }; - typedef std::unordered_map>> + typedef std::unordered_map> CacheMap; std::stack cached_maps; @@ -89,41 +88,7 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { void visit_loop(Block *body) override { cached_maps.emplace(); - - loop_blocks.push(body); - - body->accept(this); - - for (auto &[dest, status_vec] : cached_maps.top()) { - auto &[status, vec] = status_vec; - auto alloca_unique = - std::make_unique(dest->ret_type.ptr_removed()); - auto alloca_stmt = alloca_unique.get(); - modifier.insert_before(body->parent_stmt, std::move(alloca_unique)); - if (int(status) & int(CacheStatus::Read)) { - set_init_value(alloca_stmt, dest); - } - if (int(status) & int(CacheStatus::Write)) { - add_writeback(alloca_stmt, dest); - } - for (auto *stmt : vec) { - if (stmt->is()) { - auto local_load = std::make_unique(alloca_stmt); - stmt->replace_usages_with(local_load.get()); - modifier.insert_before(stmt, std::move(local_load)); - modifier.erase(stmt); - } else if (auto *global_store = stmt->cast()) { - auto local_store = - std::make_unique(alloca_stmt, global_store->val); - stmt->replace_usages_with(local_store.get()); - modifier.insert_before(stmt, std::move(local_store)); - modifier.erase(stmt); - } else { - TI_UNREACHABLE - } - } - } - loop_blocks.pop(); + LoopInvariantDetector::visit_loop(body); cached_maps.pop(); } @@ -143,30 +108,49 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { modifier.insert_before(current_loop_stmt(), std::move(local_store)); } - void cache_global_to_local(Stmt *stmt, Stmt *dest, CacheStatus status) { - if (auto &[cached_status, vec] = cached_maps.top()[dest]; + AllocaStmt *cache_global_to_local(Stmt *dest, CacheStatus status) { + if (auto &[cached_status, alloca_stmt] = cached_maps.top()[dest]; cached_status != CacheStatus::None) { // The global variable has already been cached. if (cached_status == CacheStatus::Read && status == CacheStatus::Write) { + add_writeback(alloca_stmt, dest); cached_status = CacheStatus::ReadWrite; } - vec.push_back(stmt); - return; + return alloca_stmt; } - cached_maps.top()[dest] = {status, {stmt}}; + auto alloca_unique = + std::make_unique(dest->ret_type.ptr_removed()); + auto alloca_stmt = alloca_unique.get(); + modifier.insert_before(current_loop_stmt(), std::move(alloca_unique)); + if (status == CacheStatus::Read) { + set_init_value(alloca_stmt, dest); + } else if (status == CacheStatus::Write) { + add_writeback(alloca_stmt, dest); + } + cached_maps.top()[dest] = {status, alloca_stmt}; + return alloca_stmt; } void visit(GlobalLoadStmt *stmt) override { if (is_offload_unique(stmt->src) && is_operand_loop_invariant(stmt->src, stmt->parent)) { - cache_global_to_local(stmt, stmt->src, CacheStatus::Read); + auto alloca_stmt = cache_global_to_local(stmt->src, CacheStatus::Read); + auto local_load = std::make_unique(alloca_stmt); + stmt->replace_usages_with(local_load.get()); + modifier.insert_before(stmt, std::move(local_load)); + modifier.erase(stmt); } } void visit(GlobalStoreStmt *stmt) override { if (is_offload_unique(stmt->dest) && is_operand_loop_invariant(stmt->dest, stmt->parent)) { - cache_global_to_local(stmt, stmt->dest, CacheStatus::Write); + auto alloca_stmt = cache_global_to_local(stmt->dest, CacheStatus::Write); + auto local_store = + std::make_unique(alloca_stmt, stmt->val); + stmt->replace_usages_with(local_store.get()); + modifier.insert_before(stmt, std::move(local_store)); + modifier.erase(stmt); } } @@ -193,5 +177,4 @@ bool cache_loop_invariant_global_vars(IRNode *root, return CacheLoopInvariantGlobalVars::run(root, config); } } // namespace irpass - -TLANG_NAMESPACE_END +} // namespace taichi::lang \ No newline at end of file diff --git a/taichi/transforms/loop_invariant_detector.h b/taichi/transforms/loop_invariant_detector.h index 0591e39df6ede..beb5c5b80d34e 100644 --- a/taichi/transforms/loop_invariant_detector.h +++ b/taichi/transforms/loop_invariant_detector.h @@ -6,7 +6,7 @@ #include -TLANG_NAMESPACE_BEGIN +namespace taichi::lang { class LoopInvariantDetector : public BasicStmtVisitor { public: @@ -33,7 +33,7 @@ class LoopInvariantDetector : public BasicStmtVisitor { // a = x[0] // if b: // c = a + 1 - // Since we are moving statements outside the cloest for scope, + // Since we are moving statements outside the closest for scope, // We need to check the scope of the operand Stmt *operand_parent = operand; while (operand_parent->parent) { @@ -131,4 +131,4 @@ class LoopInvariantDetector : public BasicStmtVisitor { } }; -TLANG_NAMESPACE_END +} // namespace taichi::lang \ No newline at end of file From 02edef4145fb09de5a209c1b27cc5f7ebb38f89a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 23 Sep 2022 03:35:56 +0000 Subject: [PATCH 16/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/transforms/cache_loop_invariant_global_vars.cpp | 2 +- taichi/transforms/loop_invariant_detector.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp index 18a507ce0e226..5310e5d94c8a0 100644 --- a/taichi/transforms/cache_loop_invariant_global_vars.cpp +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -177,4 +177,4 @@ bool cache_loop_invariant_global_vars(IRNode *root, return CacheLoopInvariantGlobalVars::run(root, config); } } // namespace irpass -} // namespace taichi::lang \ No newline at end of file +} // namespace taichi::lang diff --git a/taichi/transforms/loop_invariant_detector.h b/taichi/transforms/loop_invariant_detector.h index beb5c5b80d34e..d9ce963c505ca 100644 --- a/taichi/transforms/loop_invariant_detector.h +++ b/taichi/transforms/loop_invariant_detector.h @@ -131,4 +131,4 @@ class LoopInvariantDetector : public BasicStmtVisitor { } }; -} // namespace taichi::lang \ No newline at end of file +} // namespace taichi::lang