From 384dfbedf81c6035081b917eb1ef76c84d324465 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 23 Jun 2020 23:21:23 -0400 Subject: [PATCH] [Opt] [refactor] Move unreachable code elimination to a separate pass (#1315) * [Opt] [refactor] Move unreachable code elimination to a separate pass * improve * minor * [skip ci] minor * Update taichi/transforms/continue_stmt_optimization.cpp Co-authored-by: Yuanming Hu * add optimizations for if (0) and if (1) Co-authored-by: Yuanming Hu --- taichi/ir/transforms.h | 1 + .../transforms/continue_stmt_optimization.cpp | 131 ++++++++++++++++++ taichi/transforms/simplify.cpp | 10 +- 3 files changed, 135 insertions(+), 7 deletions(-) create mode 100644 taichi/transforms/continue_stmt_optimization.cpp diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index d1aaf8ec43eff..c2c36126c0d13 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -25,6 +25,7 @@ bool binary_op_simplify(IRNode *root); bool whole_kernel_cse(IRNode *root); void variable_optimization(IRNode *root, bool after_lower_access); void extract_constant(IRNode *root); +bool unreachable_code_elimination(IRNode *root); void full_simplify(IRNode *root, Kernel *kernel = nullptr); void print(IRNode *root, std::string *output = nullptr); void lower(IRNode *root); diff --git a/taichi/transforms/continue_stmt_optimization.cpp b/taichi/transforms/continue_stmt_optimization.cpp new file mode 100644 index 0000000000000..440df20680b44 --- /dev/null +++ b/taichi/transforms/continue_stmt_optimization.cpp @@ -0,0 +1,131 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/visitors.h" +#include "taichi/ir/transforms.h" + +TLANG_NAMESPACE_BEGIN + +// Unconditionally eliminate ContinueStmt's at **ends** of loops +class UselessContinueEliminator : public IRVisitor { + public: + bool modified; + + UselessContinueEliminator() : modified(false) { + allow_undefined_visitor = true; + } + + void visit(ContinueStmt *stmt) override { + stmt->parent->erase(stmt); + modified = true; + } + + void visit(IfStmt *if_stmt) override { + if (if_stmt->true_statements && if_stmt->true_statements->size()) + if_stmt->true_statements->back()->accept(this); + if (if_stmt->false_statements && if_stmt->false_statements->size()) + if_stmt->false_statements->back()->accept(this); + } +}; + +// Eliminate useless ContinueStmt, the statements after ContinueStmt and +// unreachable if branches +class UnreachableCodeEliminator : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + bool modified; + UselessContinueEliminator useless_continue_eliminator; + + UnreachableCodeEliminator() : modified(false) { + allow_undefined_visitor = true; + } + + void visit(Block *stmt_list) override { + const int block_size = stmt_list->size(); + for (int i = 0; i < block_size - 1; i++) { + if (stmt_list->statements[i]->is()) { + // Eliminate statements after ContinueStmt + for (int j = block_size - 1; j > i; j--) + stmt_list->erase(j); + modified = true; + break; + } + } + for (auto &stmt : stmt_list->statements) + stmt->accept(this); + } + + void visit_loop(Block *body) { + if (body->size()) + body->back()->accept(&useless_continue_eliminator); + body->accept(this); + } + + void visit(RangeForStmt *stmt) override { + visit_loop(stmt->body.get()); + } + + void visit(StructForStmt *stmt) override { + visit_loop(stmt->body.get()); + } + + void visit(WhileStmt *stmt) override { + visit_loop(stmt->body.get()); + } + + void visit(OffloadedStmt *stmt) override { + if (stmt->prologue) + stmt->prologue->accept(this); + if (stmt->task_type == OffloadedStmt::TaskType::range_for || + stmt->task_type == OffloadedStmt::TaskType::struct_for) + visit_loop(stmt->body.get()); + else if (stmt->body) + stmt->body->accept(this); + if (stmt->epilogue) + stmt->epilogue->accept(this); + } + + void visit(IfStmt *if_stmt) override { + if (if_stmt->cond->is() && if_stmt->cond->width() == 1) { + if (if_stmt->cond->as()->val[0].equal_value(0)) { + // if (0) + if (if_stmt->true_statements) { + if_stmt->true_statements = nullptr; + modified = true; + } + } else { + // if (1) + if (if_stmt->false_statements) { + if_stmt->false_statements = nullptr; + modified = true; + } + } + } + if (if_stmt->true_statements) + if_stmt->true_statements->accept(this); + if (if_stmt->false_statements) + if_stmt->false_statements->accept(this); + } + + static bool run(IRNode *node) { + bool modified = false; + while (true) { + UnreachableCodeEliminator eliminator; + node->accept(&eliminator); + if (eliminator.modified || + eliminator.useless_continue_eliminator.modified) { + modified = true; + } else { + break; + } + } + return modified; + } +}; + +namespace irpass { +bool unreachable_code_elimination(IRNode *root) { + TI_AUTO_PROF; + return UnreachableCodeEliminator::run(root); +} +} // namespace irpass + +TLANG_NAMESPACE_END diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 2c7262e171906..37d9cb74527f0 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -961,13 +961,7 @@ class BasicBlockSimplify : public IRVisitor { } void visit(ContinueStmt *stmt) override { - if (stmt != stmt->parent->back()) { - const int location = stmt->parent->locate(stmt); - while (location + 1 < (int)stmt->parent->size()) { - stmt->parent->erase(location + 1); - } - throw IRModified(); - } + return; } static bool is_global_write(Stmt *stmt) { @@ -1225,6 +1219,8 @@ void full_simplify(IRNode *root, Kernel *kernel) { while (true) { bool modified = false; extract_constant(root); + if (unreachable_code_elimination(root)) + modified = true; if (binary_op_simplify(root)) modified = true; if (constant_fold(root))