Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Opt] [refactor] Move unreachable code elimination to a separate pass #1315

Merged
merged 6 commits into from
Jun 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ bool binary_op_simplify(IRNode *root);
bool whole_kernel_cse(IRNode *root);
void variable_optimization(IRNode *root, bool after_lower_access);
void extract_constant(IRNode *root);
bool unreachable_code_elimination(IRNode *root);
void full_simplify(IRNode *root, Kernel *kernel = nullptr);
void print(IRNode *root, std::string *output = nullptr);
void lower(IRNode *root);
Expand Down
131 changes: 131 additions & 0 deletions taichi/transforms/continue_stmt_optimization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/visitors.h"
#include "taichi/ir/transforms.h"

TLANG_NAMESPACE_BEGIN

// Unconditionally eliminate ContinueStmt's at **ends** of loops
class UselessContinueEliminator : public IRVisitor {
public:
bool modified;

UselessContinueEliminator() : modified(false) {
allow_undefined_visitor = true;
}

void visit(ContinueStmt *stmt) override {
stmt->parent->erase(stmt);
modified = true;
}

void visit(IfStmt *if_stmt) override {
if (if_stmt->true_statements && if_stmt->true_statements->size())
if_stmt->true_statements->back()->accept(this);
if (if_stmt->false_statements && if_stmt->false_statements->size())
if_stmt->false_statements->back()->accept(this);
}
};

// Eliminate useless ContinueStmt, the statements after ContinueStmt and
// unreachable if branches
class UnreachableCodeEliminator : public BasicStmtVisitor {
public:
using BasicStmtVisitor::visit;
bool modified;
UselessContinueEliminator useless_continue_eliminator;

UnreachableCodeEliminator() : modified(false) {
allow_undefined_visitor = true;
}

void visit(Block *stmt_list) override {
const int block_size = stmt_list->size();
for (int i = 0; i < block_size - 1; i++) {
if (stmt_list->statements[i]->is<ContinueStmt>()) {
// Eliminate statements after ContinueStmt
for (int j = block_size - 1; j > i; j--)
stmt_list->erase(j);
modified = true;
break;
}
}
for (auto &stmt : stmt_list->statements)
stmt->accept(this);
}

void visit_loop(Block *body) {
if (body->size())
body->back()->accept(&useless_continue_eliminator);
body->accept(this);
}

void visit(RangeForStmt *stmt) override {
visit_loop(stmt->body.get());
}

void visit(StructForStmt *stmt) override {
visit_loop(stmt->body.get());
}

void visit(WhileStmt *stmt) override {
visit_loop(stmt->body.get());
}

void visit(OffloadedStmt *stmt) override {
if (stmt->prologue)
stmt->prologue->accept(this);
if (stmt->task_type == OffloadedStmt::TaskType::range_for ||
stmt->task_type == OffloadedStmt::TaskType::struct_for)
visit_loop(stmt->body.get());
else if (stmt->body)
stmt->body->accept(this);
if (stmt->epilogue)
stmt->epilogue->accept(this);
}

void visit(IfStmt *if_stmt) override {
if (if_stmt->cond->is<ConstStmt>() && if_stmt->cond->width() == 1) {
if (if_stmt->cond->as<ConstStmt>()->val[0].equal_value(0)) {
// if (0)
if (if_stmt->true_statements) {
if_stmt->true_statements = nullptr;
modified = true;
}
} else {
// if (1)
if (if_stmt->false_statements) {
if_stmt->false_statements = nullptr;
modified = true;
}
}
}
if (if_stmt->true_statements)
if_stmt->true_statements->accept(this);
if (if_stmt->false_statements)
if_stmt->false_statements->accept(this);
}

static bool run(IRNode *node) {
bool modified = false;
while (true) {
UnreachableCodeEliminator eliminator;
node->accept(&eliminator);
if (eliminator.modified ||
eliminator.useless_continue_eliminator.modified) {
modified = true;
} else {
break;
}
}
return modified;
}
};

namespace irpass {
bool unreachable_code_elimination(IRNode *root) {
TI_AUTO_PROF;
return UnreachableCodeEliminator::run(root);
}
} // namespace irpass

TLANG_NAMESPACE_END
10 changes: 3 additions & 7 deletions taichi/transforms/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -961,13 +961,7 @@ class BasicBlockSimplify : public IRVisitor {
}

void visit(ContinueStmt *stmt) override {
if (stmt != stmt->parent->back()) {
const int location = stmt->parent->locate(stmt);
while (location + 1 < (int)stmt->parent->size()) {
stmt->parent->erase(location + 1);
}
throw IRModified();
}
return;
}

static bool is_global_write(Stmt *stmt) {
Expand Down Expand Up @@ -1225,6 +1219,8 @@ void full_simplify(IRNode *root, Kernel *kernel) {
while (true) {
bool modified = false;
extract_constant(root);
if (unreachable_code_elimination(root))
modified = true;
if (binary_op_simplify(root))
modified = true;
if (constant_fold(root))
Expand Down