Skip to content

Commit

Permalink
[Opt] [refactor] Move unreachable code elimination to a separate pass (
Browse files Browse the repository at this point in the history
…taichi-dev#1315)

* [Opt] [refactor] Move unreachable code elimination to a separate pass

* improve

* minor

* [skip ci] minor

* Update taichi/transforms/continue_stmt_optimization.cpp

Co-authored-by: Yuanming Hu <yuanming-hu@users.noreply.github.com>

* add optimizations for if (0) and if (1)

Co-authored-by: Yuanming Hu <yuanming-hu@users.noreply.github.com>
  • Loading branch information
2 people authored and Rullec committed Jun 26, 2020
1 parent db4e602 commit f40d73e
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 7 deletions.
1 change: 1 addition & 0 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ bool binary_op_simplify(IRNode *root);
bool whole_kernel_cse(IRNode *root);
void variable_optimization(IRNode *root, bool after_lower_access);
void extract_constant(IRNode *root);
bool unreachable_code_elimination(IRNode *root);
void full_simplify(IRNode *root, Kernel *kernel = nullptr);
void print(IRNode *root, std::string *output = nullptr);
void lower(IRNode *root);
Expand Down
131 changes: 131 additions & 0 deletions taichi/transforms/continue_stmt_optimization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/visitors.h"
#include "taichi/ir/transforms.h"

TLANG_NAMESPACE_BEGIN

// Unconditionally eliminate ContinueStmt's at **ends** of loops
class UselessContinueEliminator : public IRVisitor {
public:
bool modified;

UselessContinueEliminator() : modified(false) {
allow_undefined_visitor = true;
}

void visit(ContinueStmt *stmt) override {
stmt->parent->erase(stmt);
modified = true;
}

void visit(IfStmt *if_stmt) override {
if (if_stmt->true_statements && if_stmt->true_statements->size())
if_stmt->true_statements->back()->accept(this);
if (if_stmt->false_statements && if_stmt->false_statements->size())
if_stmt->false_statements->back()->accept(this);
}
};

// Eliminate useless ContinueStmt, the statements after ContinueStmt and
// unreachable if branches
class UnreachableCodeEliminator : public BasicStmtVisitor {
public:
using BasicStmtVisitor::visit;
bool modified;
UselessContinueEliminator useless_continue_eliminator;

UnreachableCodeEliminator() : modified(false) {
allow_undefined_visitor = true;
}

void visit(Block *stmt_list) override {
const int block_size = stmt_list->size();
for (int i = 0; i < block_size - 1; i++) {
if (stmt_list->statements[i]->is<ContinueStmt>()) {
// Eliminate statements after ContinueStmt
for (int j = block_size - 1; j > i; j--)
stmt_list->erase(j);
modified = true;
break;
}
}
for (auto &stmt : stmt_list->statements)
stmt->accept(this);
}

void visit_loop(Block *body) {
if (body->size())
body->back()->accept(&useless_continue_eliminator);
body->accept(this);
}

void visit(RangeForStmt *stmt) override {
visit_loop(stmt->body.get());
}

void visit(StructForStmt *stmt) override {
visit_loop(stmt->body.get());
}

void visit(WhileStmt *stmt) override {
visit_loop(stmt->body.get());
}

void visit(OffloadedStmt *stmt) override {
if (stmt->prologue)
stmt->prologue->accept(this);
if (stmt->task_type == OffloadedStmt::TaskType::range_for ||
stmt->task_type == OffloadedStmt::TaskType::struct_for)
visit_loop(stmt->body.get());
else if (stmt->body)
stmt->body->accept(this);
if (stmt->epilogue)
stmt->epilogue->accept(this);
}

void visit(IfStmt *if_stmt) override {
if (if_stmt->cond->is<ConstStmt>() && if_stmt->cond->width() == 1) {
if (if_stmt->cond->as<ConstStmt>()->val[0].equal_value(0)) {
// if (0)
if (if_stmt->true_statements) {
if_stmt->true_statements = nullptr;
modified = true;
}
} else {
// if (1)
if (if_stmt->false_statements) {
if_stmt->false_statements = nullptr;
modified = true;
}
}
}
if (if_stmt->true_statements)
if_stmt->true_statements->accept(this);
if (if_stmt->false_statements)
if_stmt->false_statements->accept(this);
}

static bool run(IRNode *node) {
bool modified = false;
while (true) {
UnreachableCodeEliminator eliminator;
node->accept(&eliminator);
if (eliminator.modified ||
eliminator.useless_continue_eliminator.modified) {
modified = true;
} else {
break;
}
}
return modified;
}
};

namespace irpass {
bool unreachable_code_elimination(IRNode *root) {
TI_AUTO_PROF;
return UnreachableCodeEliminator::run(root);
}
} // namespace irpass

TLANG_NAMESPACE_END
10 changes: 3 additions & 7 deletions taichi/transforms/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -961,13 +961,7 @@ class BasicBlockSimplify : public IRVisitor {
}

void visit(ContinueStmt *stmt) override {
if (stmt != stmt->parent->back()) {
const int location = stmt->parent->locate(stmt);
while (location + 1 < (int)stmt->parent->size()) {
stmt->parent->erase(location + 1);
}
throw IRModified();
}
return;
}

static bool is_global_write(Stmt *stmt) {
Expand Down Expand Up @@ -1225,6 +1219,8 @@ void full_simplify(IRNode *root, Kernel *kernel) {
while (true) {
bool modified = false;
extract_constant(root);
if (unreachable_code_elimination(root))
modified = true;
if (binary_op_simplify(root))
modified = true;
if (constant_fold(root))
Expand Down

0 comments on commit f40d73e

Please sign in to comment.