From de3f482464ec2833e8851f42eb3609952eb2220d Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 23 Jun 2020 19:03:18 -0400 Subject: [PATCH 1/6] [Opt] [refactor] Move unreachable code elimination to a separate pass --- taichi/ir/transforms.h | 1 + .../transforms/continue_stmt_optimization.cpp | 95 +++++++++++++++++++ taichi/transforms/simplify.cpp | 10 +- 3 files changed, 99 insertions(+), 7 deletions(-) create mode 100644 taichi/transforms/continue_stmt_optimization.cpp diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index d1aaf8ec43eff..e67c264d2a2c7 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -25,6 +25,7 @@ bool binary_op_simplify(IRNode *root); bool whole_kernel_cse(IRNode *root); void variable_optimization(IRNode *root, bool after_lower_access); void extract_constant(IRNode *root); +bool continue_stmt_optimization(IRNode *root); void full_simplify(IRNode *root, Kernel *kernel = nullptr); void print(IRNode *root, std::string *output = nullptr); void lower(IRNode *root); diff --git a/taichi/transforms/continue_stmt_optimization.cpp b/taichi/transforms/continue_stmt_optimization.cpp new file mode 100644 index 0000000000000..658652a0aa0ba --- /dev/null +++ b/taichi/transforms/continue_stmt_optimization.cpp @@ -0,0 +1,95 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/visitors.h" +#include "taichi/ir/transforms.h" + +TLANG_NAMESPACE_BEGIN + +// Eliminate useless ContinueStmt +class UselessContinueEliminator : public IRVisitor { + public: + bool modified; + + UselessContinueEliminator() : modified(false) { + allow_undefined_visitor = true; + } + + void visit(ContinueStmt *stmt) override { + stmt->parent->erase(stmt); + modified = true; + } + + void visit(IfStmt *if_stmt) override { + if (if_stmt->true_statements && if_stmt->true_statements->size()) + if_stmt->true_statements->back()->accept(this); + if (if_stmt->false_statements && if_stmt->false_statements->size()) + if_stmt->false_statements->back()->accept(this); + } +}; + +// Eliminate useless ContinueStmt and the statements after ContinueStmt +class ContinueStmtOptimizer : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + bool modified; + UselessContinueEliminator useless_continue_eliminator; + + ContinueStmtOptimizer() : modified(false) { + allow_undefined_visitor = true; + } + + void visit_loop(Stmt *loop_stmt, Block *body) { + const int body_size = body->size(); + for (int i = 0; i < body_size - 1; i++) { + if (auto continue_stmt = body->statements[i]->cast()) { + TI_ASSERT(continue_stmt->scope == loop_stmt || + continue_stmt->scope == nullptr); + // Eliminate statements after ContinueStmt + for (int j = body_size - 1; j > i; j--) + body->erase(j); + modified = true; + } + } + if (body->size()) + body->back()->accept(&useless_continue_eliminator); + body->accept(this); + } + + void visit(RangeForStmt *stmt) override { + visit_loop(stmt, stmt->body.get()); + } + + void visit(StructForStmt *stmt) override { + visit_loop(stmt, stmt->body.get()); + } + + void visit(WhileStmt *stmt) override { + visit_loop(stmt, stmt->body.get()); + } + + void visit(OffloadedStmt *stmt) override { + if (stmt->prologue) + stmt->prologue->accept(this); + if (stmt->task_type == OffloadedStmt::TaskType::range_for || + stmt->task_type == OffloadedStmt::TaskType::struct_for) + visit_loop(stmt, stmt->body.get()); + else if (stmt->body) + stmt->body->accept(this); + if (stmt->epilogue) + stmt->epilogue->accept(this); + } + + static bool run(IRNode *node) { + ContinueStmtOptimizer optimizer; + node->accept(&optimizer); + return optimizer.modified || optimizer.useless_continue_eliminator.modified; + } +}; + +namespace irpass { +bool continue_stmt_optimization(IRNode *root) { + TI_AUTO_PROF; + return ContinueStmtOptimizer::run(root); +} +} // namespace irpass + +TLANG_NAMESPACE_END diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 2c7262e171906..64d6979ca0eac 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -961,13 +961,7 @@ class BasicBlockSimplify : public IRVisitor { } void visit(ContinueStmt *stmt) override { - if (stmt != stmt->parent->back()) { - const int location = stmt->parent->locate(stmt); - while (location + 1 < (int)stmt->parent->size()) { - stmt->parent->erase(location + 1); - } - throw IRModified(); - } + return; } static bool is_global_write(Stmt *stmt) { @@ -1225,6 +1219,8 @@ void full_simplify(IRNode *root, Kernel *kernel) { while (true) { bool modified = false; extract_constant(root); + if (continue_stmt_optimization(root)) + modified = true; if (binary_op_simplify(root)) modified = true; if (constant_fold(root)) From 33b4d5d714316de57d07f1c178b737410e990275 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 23 Jun 2020 19:20:42 -0400 Subject: [PATCH 2/6] improve --- .../transforms/continue_stmt_optimization.cpp | 42 ++++++++++++------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/taichi/transforms/continue_stmt_optimization.cpp b/taichi/transforms/continue_stmt_optimization.cpp index 658652a0aa0ba..b038f5bf16c39 100644 --- a/taichi/transforms/continue_stmt_optimization.cpp +++ b/taichi/transforms/continue_stmt_optimization.cpp @@ -37,33 +37,36 @@ class ContinueStmtOptimizer : public BasicStmtVisitor { allow_undefined_visitor = true; } - void visit_loop(Stmt *loop_stmt, Block *body) { - const int body_size = body->size(); - for (int i = 0; i < body_size - 1; i++) { - if (auto continue_stmt = body->statements[i]->cast()) { - TI_ASSERT(continue_stmt->scope == loop_stmt || - continue_stmt->scope == nullptr); + void visit(Block *stmt_list) override { + const int block_size = stmt_list->size(); + for (int i = 0; i < block_size - 1; i++) { + if (auto continue_stmt = stmt_list->statements[i]->cast()) { // Eliminate statements after ContinueStmt - for (int j = body_size - 1; j > i; j--) - body->erase(j); + for (int j = block_size - 1; j > i; j--) + stmt_list->erase(j); modified = true; } } + for (auto &stmt : stmt_list->statements) + stmt->accept(this); + } + + void visit_loop(Block *body) { if (body->size()) body->back()->accept(&useless_continue_eliminator); body->accept(this); } void visit(RangeForStmt *stmt) override { - visit_loop(stmt, stmt->body.get()); + visit_loop(stmt->body.get()); } void visit(StructForStmt *stmt) override { - visit_loop(stmt, stmt->body.get()); + visit_loop(stmt->body.get()); } void visit(WhileStmt *stmt) override { - visit_loop(stmt, stmt->body.get()); + visit_loop(stmt->body.get()); } void visit(OffloadedStmt *stmt) override { @@ -71,7 +74,7 @@ class ContinueStmtOptimizer : public BasicStmtVisitor { stmt->prologue->accept(this); if (stmt->task_type == OffloadedStmt::TaskType::range_for || stmt->task_type == OffloadedStmt::TaskType::struct_for) - visit_loop(stmt, stmt->body.get()); + visit_loop(stmt->body.get()); else if (stmt->body) stmt->body->accept(this); if (stmt->epilogue) @@ -79,9 +82,18 @@ class ContinueStmtOptimizer : public BasicStmtVisitor { } static bool run(IRNode *node) { - ContinueStmtOptimizer optimizer; - node->accept(&optimizer); - return optimizer.modified || optimizer.useless_continue_eliminator.modified; + bool modified = false; + while (true) { + ContinueStmtOptimizer optimizer; + node->accept(&optimizer); + if (optimizer.modified || + optimizer.useless_continue_eliminator.modified) { + modified = true; + } else { + break; + } + } + return modified; } }; From c107777184c099b0d7b10a8ce21e0626841c0a20 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 23 Jun 2020 19:22:34 -0400 Subject: [PATCH 3/6] minor --- taichi/transforms/continue_stmt_optimization.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/taichi/transforms/continue_stmt_optimization.cpp b/taichi/transforms/continue_stmt_optimization.cpp index b038f5bf16c39..c93e8d9b95a8f 100644 --- a/taichi/transforms/continue_stmt_optimization.cpp +++ b/taichi/transforms/continue_stmt_optimization.cpp @@ -45,6 +45,7 @@ class ContinueStmtOptimizer : public BasicStmtVisitor { for (int j = block_size - 1; j > i; j--) stmt_list->erase(j); modified = true; + break; } } for (auto &stmt : stmt_list->statements) From 4b1ca907b411ac65be841c73a6ea943ff61d592b Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 23 Jun 2020 19:23:10 -0400 Subject: [PATCH 4/6] [skip ci] minor --- taichi/transforms/continue_stmt_optimization.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/transforms/continue_stmt_optimization.cpp b/taichi/transforms/continue_stmt_optimization.cpp index c93e8d9b95a8f..1607d0bc972a9 100644 --- a/taichi/transforms/continue_stmt_optimization.cpp +++ b/taichi/transforms/continue_stmt_optimization.cpp @@ -40,7 +40,7 @@ class ContinueStmtOptimizer : public BasicStmtVisitor { void visit(Block *stmt_list) override { const int block_size = stmt_list->size(); for (int i = 0; i < block_size - 1; i++) { - if (auto continue_stmt = stmt_list->statements[i]->cast()) { + if (stmt_list->statements[i]->is()) { // Eliminate statements after ContinueStmt for (int j = block_size - 1; j > i; j--) stmt_list->erase(j); From 3bbf02406a7d00cbcad4ad34d7bf54df428acdae Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 23 Jun 2020 21:49:13 -0400 Subject: [PATCH 5/6] Update taichi/transforms/continue_stmt_optimization.cpp Co-authored-by: Yuanming Hu --- taichi/transforms/continue_stmt_optimization.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/transforms/continue_stmt_optimization.cpp b/taichi/transforms/continue_stmt_optimization.cpp index 1607d0bc972a9..4442bad5b9f9b 100644 --- a/taichi/transforms/continue_stmt_optimization.cpp +++ b/taichi/transforms/continue_stmt_optimization.cpp @@ -4,7 +4,7 @@ TLANG_NAMESPACE_BEGIN -// Eliminate useless ContinueStmt +// Unconditionally eliminate ContinueStmt's at **ends** of loops class UselessContinueEliminator : public IRVisitor { public: bool modified; From 06f9cfdb06fc89dc3f4f401681cfa27c8faa1077 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 23 Jun 2020 22:03:03 -0400 Subject: [PATCH 6/6] add optimizations for if (0) and if (1) --- taichi/ir/transforms.h | 2 +- .../transforms/continue_stmt_optimization.cpp | 41 +++++++++++++++---- taichi/transforms/simplify.cpp | 2 +- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index e67c264d2a2c7..c2c36126c0d13 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -25,7 +25,7 @@ bool binary_op_simplify(IRNode *root); bool whole_kernel_cse(IRNode *root); void variable_optimization(IRNode *root, bool after_lower_access); void extract_constant(IRNode *root); -bool continue_stmt_optimization(IRNode *root); +bool unreachable_code_elimination(IRNode *root); void full_simplify(IRNode *root, Kernel *kernel = nullptr); void print(IRNode *root, std::string *output = nullptr); void lower(IRNode *root); diff --git a/taichi/transforms/continue_stmt_optimization.cpp b/taichi/transforms/continue_stmt_optimization.cpp index 4442bad5b9f9b..440df20680b44 100644 --- a/taichi/transforms/continue_stmt_optimization.cpp +++ b/taichi/transforms/continue_stmt_optimization.cpp @@ -26,14 +26,15 @@ class UselessContinueEliminator : public IRVisitor { } }; -// Eliminate useless ContinueStmt and the statements after ContinueStmt -class ContinueStmtOptimizer : public BasicStmtVisitor { +// Eliminate useless ContinueStmt, the statements after ContinueStmt and +// unreachable if branches +class UnreachableCodeEliminator : public BasicStmtVisitor { public: using BasicStmtVisitor::visit; bool modified; UselessContinueEliminator useless_continue_eliminator; - ContinueStmtOptimizer() : modified(false) { + UnreachableCodeEliminator() : modified(false) { allow_undefined_visitor = true; } @@ -82,13 +83,35 @@ class ContinueStmtOptimizer : public BasicStmtVisitor { stmt->epilogue->accept(this); } + void visit(IfStmt *if_stmt) override { + if (if_stmt->cond->is() && if_stmt->cond->width() == 1) { + if (if_stmt->cond->as()->val[0].equal_value(0)) { + // if (0) + if (if_stmt->true_statements) { + if_stmt->true_statements = nullptr; + modified = true; + } + } else { + // if (1) + if (if_stmt->false_statements) { + if_stmt->false_statements = nullptr; + modified = true; + } + } + } + if (if_stmt->true_statements) + if_stmt->true_statements->accept(this); + if (if_stmt->false_statements) + if_stmt->false_statements->accept(this); + } + static bool run(IRNode *node) { bool modified = false; while (true) { - ContinueStmtOptimizer optimizer; - node->accept(&optimizer); - if (optimizer.modified || - optimizer.useless_continue_eliminator.modified) { + UnreachableCodeEliminator eliminator; + node->accept(&eliminator); + if (eliminator.modified || + eliminator.useless_continue_eliminator.modified) { modified = true; } else { break; @@ -99,9 +122,9 @@ class ContinueStmtOptimizer : public BasicStmtVisitor { }; namespace irpass { -bool continue_stmt_optimization(IRNode *root) { +bool unreachable_code_elimination(IRNode *root) { TI_AUTO_PROF; - return ContinueStmtOptimizer::run(root); + return UnreachableCodeEliminator::run(root); } } // namespace irpass diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 64d6979ca0eac..37d9cb74527f0 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -1219,7 +1219,7 @@ void full_simplify(IRNode *root, Kernel *kernel) { while (true) { bool modified = false; extract_constant(root); - if (continue_stmt_optimization(root)) + if (unreachable_code_elimination(root)) modified = true; if (binary_op_simplify(root)) modified = true;