From 50ad29aeba14e1738a1f26996d62f7661d66a652 Mon Sep 17 00:00:00 2001
From: xumingkuan
Date: Thu, 23 Apr 2020 21:48:07 -0400
Subject: [PATCH] [opt] Eliminate useless local stores and atomics (#858)

* eliminate useless local stores

* handle IfStmt properly

* fix test

* eliminate atomics

* add irpass::analysis::gather_used_atomics for better eliminating atomics

* [skip ci] enforce code format

* [skip ci] avoid confusing "for"

* add a test

Co-authored-by: Taichi Gardener
---
 taichi/analysis/gather_used_atomics.cpp       |  45 ++++++
 taichi/ir/ir.h                                |   1 +
 taichi/transforms/optimize_local_variable.cpp | 143 ++++++++++++++----
 tests/python/test_optimization.py             |  26 ++++
 4 files changed, 183 insertions(+), 32 deletions(-)
 create mode 100644 taichi/analysis/gather_used_atomics.cpp

diff --git a/taichi/analysis/gather_used_atomics.cpp b/taichi/analysis/gather_used_atomics.cpp
new file mode 100644
index 0000000000000..6eb1b147e790c
--- /dev/null
+++ b/taichi/analysis/gather_used_atomics.cpp
@@ -0,0 +1,45 @@
+#include "taichi/ir/ir.h"
+#include <unordered_set>
+
+TLANG_NAMESPACE_BEGIN
+
+class UsedAtomicsSearcher : public BasicStmtVisitor {
+ private:
+  std::unordered_set<AtomicOpStmt *> used_atomics;
+
+ public:
+  UsedAtomicsSearcher() {
+    allow_undefined_visitor = true;
+    invoke_default_visitor = true;
+  }
+
+  void search_operands(Stmt *stmt) {
+    for (auto &op : stmt->get_operands()) {
+      if (op != nullptr && op->is<AtomicOpStmt>()) {
+        used_atomics.insert(op->as<AtomicOpStmt>());
+      }
+    }
+  }
+
+  void preprocess_container_stmt(Stmt *stmt) override {
+    search_operands(stmt);
+  }
+
+  void visit(Stmt *stmt) override {
+    search_operands(stmt);
+  }
+
+  static std::unordered_set<AtomicOpStmt *> run(IRNode *root) {
+    UsedAtomicsSearcher searcher;
+    root->accept(&searcher);
+    return searcher.used_atomics;
+  }
+};
+
+namespace irpass::analysis {
+std::unordered_set<AtomicOpStmt *> gather_used_atomics(IRNode *root) {
+  return UsedAtomicsSearcher::run(root);
+}
+}  // namespace irpass::analysis
+
+TLANG_NAMESPACE_END
diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h
index cce7c0cbb7119..5fca8396fa1ad 100644
--- a/taichi/ir/ir.h
+++ b/taichi/ir/ir.h
@@ -128,6 +128,7 @@ std::unordered_set<Stmt *> detect_loops_with_continue(IRNode *root);
 std::unordered_set<SNode *> gather_deactivations(IRNode *root);
 std::vector<Stmt *> gather_statements(IRNode *root,
                                       const std::function<bool(Stmt *)> &test);
+std::unordered_set<AtomicOpStmt *> gather_used_atomics(IRNode *root);
 bool has_load_or_atomic(IRNode *root, Stmt *var);
 bool has_store_or_atomic(IRNode *root, const std::vector<Stmt *> &vars);
 std::pair<bool, Stmt *> last_store_or_atomic(IRNode *root, Stmt *var);
diff --git a/taichi/transforms/optimize_local_variable.cpp b/taichi/transforms/optimize_local_variable.cpp
index 7251ff9a7f646..f66b24580c10f 100644
--- a/taichi/transforms/optimize_local_variable.cpp
+++ b/taichi/transforms/optimize_local_variable.cpp
@@ -7,6 +7,7 @@ TLANG_NAMESPACE_BEGIN
 class AllocaOptimize : public IRVisitor {
  private:
   AllocaStmt *alloca_stmt;
+  std::unordered_set<AtomicOpStmt *> *used_atomics;
 
  public:
   // If neither stored nor loaded (nor used as operands in masks/loop_vars),
@@ -22,8 +23,6 @@ class AllocaOptimize : public IRVisitor {
   bool last_store_valid;
 
   // last_store_loaded: Is the last store ever loaded? If not, eliminate it.
-  // If stored is false, last_store_loaded means if the alloca is ever loaded,
-  // but it should not be used.
   bool last_store_loaded;
 
   AtomicOpStmt *last_atomic;
@@ -45,8 +44,17 @@ class AllocaOptimize : public IRVisitor {
   };
   IsInsideLoop is_inside_loop;
 
-  explicit AllocaOptimize(AllocaStmt *alloca_stmt)
+  // Is this alloca ever stored in the current block?
+  bool stored_in_current_block;
+  // Is this alloca ever loaded before the first store in the current block?
+  // If this block is a branch of IfStmt, we will use this variable to
+  // determine whether we can eliminate the last store before the IfStmt.
+  bool loaded_before_first_store_in_current_block;
+
+  explicit AllocaOptimize(AllocaStmt *alloca_stmt,
+                          std::unordered_set<AtomicOpStmt *> *used_atomics)
       : alloca_stmt(alloca_stmt),
+        used_atomics(used_atomics),
         stored(false),
         loaded(false),
         last_store(nullptr),
@@ -54,7 +62,9 @@ class AllocaOptimize : public IRVisitor {
         last_store_loaded(false),
         last_atomic(nullptr),
         last_atomic_eliminable(false),
-        is_inside_loop(outside_loop) {
+        is_inside_loop(outside_loop),
+        stored_in_current_block(false),
+        loaded_before_first_store_in_current_block(false) {
     allow_undefined_visitor = true;
     invoke_default_visitor = true;
   }
@@ -68,24 +78,39 @@ class AllocaOptimize : public IRVisitor {
   void visit(AtomicOpStmt *stmt) override {
     if (stmt->dest != alloca_stmt)
       return;
+    // This statement is loading the last store, so we can't eliminate it.
     stored = true;
     loaded = true;
     last_store = nullptr;
     last_store_valid = false;
     last_store_loaded = false;
     last_atomic = stmt;
-    last_atomic_eliminable = true;
+    last_atomic_eliminable = used_atomics->find(stmt) == used_atomics->end();
+    if (!stored_in_current_block)
+      loaded_before_first_store_in_current_block = true;
+    stored_in_current_block = true;
   }
 
   void visit(LocalStoreStmt *stmt) override {
     if (stmt->ptr != alloca_stmt)
       return;
+    if (last_store && !last_store_loaded) {
+      // The last store is never loaded.
+      last_store->parent->erase(last_store);
+      throw IRModified();
+    }
+    if (last_atomic && last_atomic_eliminable) {
+      // The last AtomicOpStmt is never used.
+      last_atomic->parent->erase(last_atomic);
+      throw IRModified();
+    }
     stored = true;
     last_store = stmt;
     last_store_valid = true;
     last_store_loaded = false;
     last_atomic = nullptr;
     last_atomic_eliminable = false;
+    stored_in_current_block = true;
   }
 
   void visit(LocalLoadStmt *stmt) override {
@@ -100,6 +125,8 @@ class AllocaOptimize : public IRVisitor {
           last_store_loaded = true;
         if (last_atomic)
           last_atomic_eliminable = false;
+        if (!stored_in_current_block)
+          loaded_before_first_store_in_current_block = true;
       }
     }
     if (!regular)
@@ -121,13 +148,26 @@ class AllocaOptimize : public IRVisitor {
     }
   }
 
+  AllocaOptimize new_instance_if_stmt() const {
+    AllocaOptimize new_instance = *this;
+    // Avoid eliminating the last store and the last AtomicOpStmt.
+    if (last_store)
+      new_instance.last_store_loaded = true;
+    if (last_atomic)
+      new_instance.last_atomic_eliminable = false;
+
+    new_instance.stored_in_current_block = false;
+    new_instance.loaded_before_first_store_in_current_block = false;
+    return new_instance;
+  }
+
   void visit(IfStmt *if_stmt) override {
     TI_ASSERT(if_stmt->true_mask == nullptr);
     TI_ASSERT(if_stmt->false_mask == nullptr);
     // Create two new instances for IfStmt
-    AllocaOptimize true_branch = *this;
-    AllocaOptimize false_branch = *this;
+    AllocaOptimize true_branch = new_instance_if_stmt();
+    AllocaOptimize false_branch = new_instance_if_stmt();
     if (if_stmt->true_statements) {
       if_stmt->true_statements->accept(&true_branch);
     }
@@ -135,6 +175,25 @@ class AllocaOptimize : public IRVisitor {
       if_stmt->false_statements->accept(&false_branch);
     }
 
+    if (last_store && !last_store_loaded &&
+        true_branch.stored_in_current_block &&
+        !true_branch.loaded_before_first_store_in_current_block &&
+        false_branch.stored_in_current_block &&
+        !false_branch.loaded_before_first_store_in_current_block) {
+      // The last store before the IfStmt is never loaded.
+      last_store->parent->erase(last_store);
+      throw IRModified();
+    }
+    if (last_atomic && last_atomic_eliminable &&
+        true_branch.stored_in_current_block &&
+        !true_branch.loaded_before_first_store_in_current_block &&
+        false_branch.stored_in_current_block &&
+        !false_branch.loaded_before_first_store_in_current_block) {
+      // The last AtomicOpStmt is never used.
+      last_atomic->parent->erase(last_atomic);
+      throw IRModified();
+    }
+
     stored = true_branch.stored || false_branch.stored;
     loaded = true_branch.loaded || false_branch.loaded;
 
@@ -145,9 +204,12 @@ class AllocaOptimize : public IRVisitor {
       TI_ASSERT(true_branch.last_store != nullptr);
       last_store_valid = true;
       if (last_store == true_branch.last_store) {
-        last_store_loaded = last_store_loaded ||
-                            true_branch.last_store_loaded ||
-                            false_branch.last_store_loaded;
+        TI_ASSERT(!true_branch.stored_in_current_block);
+        TI_ASSERT(!false_branch.stored_in_current_block);
+        last_store_loaded =
+            last_store_loaded ||
+            true_branch.loaded_before_first_store_in_current_block ||
+            false_branch.loaded_before_first_store_in_current_block;
       } else {
         last_store = true_branch.last_store;
         last_store_loaded =
@@ -159,11 +221,18 @@ class AllocaOptimize : public IRVisitor {
       if (true_branch.last_store == last_store &&
           false_branch.last_store == last_store) {
         // The last store didn't change.
-        last_store_loaded = last_store_loaded ||
-                            true_branch.last_store_loaded ||
-                            false_branch.last_store_loaded;
+        TI_ASSERT(!true_branch.stored_in_current_block);
+        TI_ASSERT(!false_branch.stored_in_current_block);
+        last_store_loaded =
+            last_store_loaded ||
+            true_branch.loaded_before_first_store_in_current_block ||
+            false_branch.loaded_before_first_store_in_current_block;
       } else {
-        // The last store changed, so we can't eliminate last_store.
+        // The last store changed.
+        bool current_eliminable =
+            last_store && !last_store_loaded &&
+            !true_branch.loaded_before_first_store_in_current_block &&
+            !false_branch.loaded_before_first_store_in_current_block;
         bool true_eliminable = true_branch.last_store != last_store &&
                                true_branch.last_store != nullptr &&
                                !true_branch.last_store_loaded;
@@ -176,6 +245,10 @@ class AllocaOptimize : public IRVisitor {
       } else if (false_eliminable) {
         last_store = false_branch.last_store;
         last_store_loaded = false;
+      } else if (current_eliminable) {
+        TI_ASSERT(!true_branch.stored_in_current_block ||
+                  !false_branch.stored_in_current_block);
+        last_store_loaded = false;
       } else {
         // Neither branch provides a eliminable local store.
         last_store = nullptr;
@@ -187,11 +260,16 @@ class AllocaOptimize : public IRVisitor {
     if (true_branch.last_atomic == last_atomic &&
         false_branch.last_atomic == last_atomic) {
       // The last AtomicOpStmt didn't change.
-      last_atomic_eliminable = last_atomic_eliminable &&
-                               true_branch.last_atomic_eliminable &&
-                               false_branch.last_atomic_eliminable;
+      last_atomic_eliminable =
+          last_atomic_eliminable &&
+          !true_branch.loaded_before_first_store_in_current_block &&
+          !false_branch.loaded_before_first_store_in_current_block;
     } else {
-      // The last AtomicOpStmt changed, so we can't eliminate last_atomic.
+      // The last AtomicOpStmt changed.
+      bool current_eliminable =
+          last_atomic && last_atomic_eliminable &&
+          !true_branch.loaded_before_first_store_in_current_block &&
+          !false_branch.loaded_before_first_store_in_current_block;
       bool true_eliminable = true_branch.last_atomic != last_atomic &&
                              true_branch.last_atomic != nullptr &&
                              true_branch.last_atomic_eliminable;
@@ -204,6 +282,10 @@ class AllocaOptimize : public IRVisitor {
       } else if (false_eliminable) {
         last_atomic = false_branch.last_atomic;
         last_atomic_eliminable = true;
+      } else if (current_eliminable) {
+        TI_ASSERT(!true_branch.stored_in_current_block ||
+                  !false_branch.stored_in_current_block);
+        last_atomic_eliminable = true;
       } else {
         // Neither branch provides a eliminable AtomicOpStmt.
         last_atomic = nullptr;
@@ -230,7 +312,7 @@ class AllocaOptimize : public IRVisitor {
       body->accept(this);
       return;
     }
-    AllocaOptimize loop(alloca_stmt);
+    AllocaOptimize loop(alloca_stmt, used_atomics);
     loop.is_inside_loop = inside_loop_may_have_stores;
     body->accept(&loop);
 
@@ -312,17 +394,11 @@ class AllocaOptimize : public IRVisitor {
       throw IRModified();
     }
     if (last_atomic && last_atomic_eliminable) {
-      // The last AtomicOpStmt is never loaded.
+      // The last AtomicOpStmt is never used.
       // last_atomic_valid == false means that it's in an IfStmt.
-      if (irpass::analysis::gather_statements(
-              block,
-              [&](Stmt *stmt) { return stmt->have_operand(last_atomic); })
-              .empty()) {
-        // The last AtomicOpStmt is never used.
-        // Eliminate the last AtomicOpStmt.
-        last_atomic->parent->erase(last_atomic);
-        throw IRModified();
-      }
+      // Eliminate the last AtomicOpStmt.
+      last_atomic->parent->erase(last_atomic);
+      throw IRModified();
     }
     if (!stored && !loaded) {
       // Never stored and never loaded.
@@ -338,11 +414,13 @@ class AllocaFindAndOptimize : public BasicStmtVisitor {
  private:
   std::unordered_set<int> visited;
+  std::unordered_set<AtomicOpStmt *> *used_atomics;
 
  public:
   using BasicStmtVisitor::visit;
 
-  AllocaFindAndOptimize() : visited() {
+  AllocaFindAndOptimize(std::unordered_set<AtomicOpStmt *> *used_atomics)
+      : visited(), used_atomics(used_atomics) {
     allow_undefined_visitor = true;
     invoke_default_visitor = true;
   }
@@ -358,13 +436,14 @@ class AllocaFindAndOptimize : public BasicStmtVisitor {
   void visit(AllocaStmt *alloca_stmt) override {
     if (is_done(alloca_stmt))
       return;
-    AllocaOptimize optimizer(alloca_stmt);
+    AllocaOptimize optimizer(alloca_stmt, used_atomics);
     optimizer.run();
     set_done(alloca_stmt);
   }
 
   static void run(IRNode *node) {
-    AllocaFindAndOptimize find_and_optimizer;
+    auto used_atomics = irpass::analysis::gather_used_atomics(node);
+    AllocaFindAndOptimize find_and_optimizer(&used_atomics);
     while (true) {
       bool modified = false;
       try {
diff --git a/tests/python/test_optimization.py b/tests/python/test_optimization.py
index d9b04a5851c89..14fa238368210 100644
--- a/tests/python/test_optimization.py
+++ b/tests/python/test_optimization.py
@@ -21,3 +21,29 @@ def func():
         val[None] = 10
     func()
     assert val[None] == 10
+
+
+@ti.all_archs
+def test_advanced_unused_store_elimination_if():
+    val = ti.var(ti.i32)
+    ti.root.place(val)
+
+    @ti.kernel
+    def func():
+        a = 1
+        if val[None]:
+            a = 2
+            if val[None]:
+                a = 3
+            else:
+                a = 4
+            val[None] = a
+        else:
+            val[None] = a
+
+
+    val[None] = 0
+    func()
+    assert val[None] == 1
+    func()
+    assert val[None] == 3
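
Note (illustrative sketch, not part of the patch above): the new test only covers dead local stores around an IfStmt. The atomic side of this change (AtomicOpStmt elimination driven by irpass::analysis::gather_used_atomics) could be exercised with a kernel like the one below. The test name is hypothetical and the snippet assumes the same ti.var / ti.root.place / @ti.kernel API already used in tests/python/test_optimization.py; it sketches the intended behavior rather than reproducing code from this PR.

@ti.all_archs
def test_unused_atomic_elimination_sketch():  # hypothetical test, not in this PR
    val = ti.var(ti.i32)
    ti.root.place(val)

    @ti.kernel
    def func():
        a = 0
        # The return value of this atomic is never used, and `a` is overwritten
        # below before it is ever loaded, so the optimizer should be free to
        # drop both the atomic and the initial store of 0.
        ti.atomic_add(a, 5)
        a = 7
        val[None] = a

    func()
    assert val[None] == 7

Either way the kernel writes 7 into val, so the assertion holds whether or not the elimination fires; the intent is that after this pass the kernel's IR should no longer contain the AtomicOpStmt on the local variable.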