Skip to content

Commit

Permalink
[opt] Eliminate useless local stores and atomics (#858)
Browse files Browse the repository at this point in the history
* eliminate useless local stores

* handle IfStmt properly

* fix test

* eliminate atomics

* add irpass::analysis::gather_used_atomics for better eliminating atomics

* [skip ci] enforce code format

* [skip ci] avoid confusing "for"

* add a test

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
xumingkuan and taichi-gardener authored Apr 24, 2020
1 parent 0940e9a commit 50ad29a
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 32 deletions.
45 changes: 45 additions & 0 deletions taichi/analysis/gather_used_atomics.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include "taichi/ir/ir.h"
#include <unordered_set>

TLANG_NAMESPACE_BEGIN

class UsedAtomicsSearcher : public BasicStmtVisitor {
private:
std::unordered_set<AtomicOpStmt *> used_atomics;

public:
UsedAtomicsSearcher() {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}

void search_operands(Stmt *stmt) {
for (auto &op : stmt->get_operands()) {
if (op != nullptr && op->is<AtomicOpStmt>()) {
used_atomics.insert(op->as<AtomicOpStmt>());
}
}
}

void preprocess_container_stmt(Stmt *stmt) override {
search_operands(stmt);
}

void visit(Stmt *stmt) override {
search_operands(stmt);
}

static std::unordered_set<AtomicOpStmt *> run(IRNode *root) {
UsedAtomicsSearcher searcher;
root->accept(&searcher);
return searcher.used_atomics;
}
};

namespace irpass::analysis {
std::unordered_set<AtomicOpStmt *> gather_used_atomics(IRNode *root) {
return UsedAtomicsSearcher::run(root);
}
} // namespace irpass::analysis

TLANG_NAMESPACE_END
1 change: 1 addition & 0 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ std::unordered_set<Stmt *> detect_loops_with_continue(IRNode *root);
std::unordered_set<SNode *> gather_deactivations(IRNode *root);
std::vector<Stmt *> gather_statements(IRNode *root,
const std::function<bool(Stmt *)> &test);
std::unordered_set<AtomicOpStmt *> gather_used_atomics(IRNode *root);
bool has_load_or_atomic(IRNode *root, Stmt *var);
bool has_store_or_atomic(IRNode *root, const std::vector<Stmt *> &vars);
std::pair<bool, Stmt *> last_store_or_atomic(IRNode *root, Stmt *var);
Expand Down
143 changes: 111 additions & 32 deletions taichi/transforms/optimize_local_variable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ TLANG_NAMESPACE_BEGIN
class AllocaOptimize : public IRVisitor {
private:
AllocaStmt *alloca_stmt;
std::unordered_set<AtomicOpStmt *> *used_atomics;

public:
// If neither stored nor loaded (nor used as operands in masks/loop_vars),
Expand All @@ -22,8 +23,6 @@ class AllocaOptimize : public IRVisitor {
bool last_store_valid;

// last_store_loaded: Is the last store ever loaded? If not, eliminate it.
// If stored is false, last_store_loaded means if the alloca is ever loaded,
// but it should not be used.
bool last_store_loaded;

AtomicOpStmt *last_atomic;
Expand All @@ -45,16 +44,27 @@ class AllocaOptimize : public IRVisitor {
};
IsInsideLoop is_inside_loop;

explicit AllocaOptimize(AllocaStmt *alloca_stmt)
// Is this alloca ever stored in the current block?
bool stored_in_current_block;
// Is this alloca ever loaded before the first store in the current block?
// If this block is a branch of IfStmt, we will use this variable to
// determine whether we can eliminate the last store before the IfStmt.
bool loaded_before_first_store_in_current_block;

explicit AllocaOptimize(AllocaStmt *alloca_stmt,
std::unordered_set<AtomicOpStmt *> *used_atomics)
: alloca_stmt(alloca_stmt),
used_atomics(used_atomics),
stored(false),
loaded(false),
last_store(nullptr),
last_store_valid(false),
last_store_loaded(false),
last_atomic(nullptr),
last_atomic_eliminable(false),
is_inside_loop(outside_loop) {
is_inside_loop(outside_loop),
stored_in_current_block(false),
loaded_before_first_store_in_current_block(false) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}
Expand All @@ -68,24 +78,39 @@ class AllocaOptimize : public IRVisitor {
void visit(AtomicOpStmt *stmt) override {
if (stmt->dest != alloca_stmt)
return;
// This statement is loading the last store, so we can't eliminate it.
stored = true;
loaded = true;
last_store = nullptr;
last_store_valid = false;
last_store_loaded = false;
last_atomic = stmt;
last_atomic_eliminable = true;
last_atomic_eliminable = used_atomics->find(stmt) == used_atomics->end();
if (!stored_in_current_block)
loaded_before_first_store_in_current_block = true;
stored_in_current_block = true;
}

void visit(LocalStoreStmt *stmt) override {
if (stmt->ptr != alloca_stmt)
return;
if (last_store && !last_store_loaded) {
// The last store is never loaded.
last_store->parent->erase(last_store);
throw IRModified();
}
if (last_atomic && last_atomic_eliminable) {
// The last AtomicOpStmt is never used.
last_atomic->parent->erase(last_atomic);
throw IRModified();
}
stored = true;
last_store = stmt;
last_store_valid = true;
last_store_loaded = false;
last_atomic = nullptr;
last_atomic_eliminable = false;
stored_in_current_block = true;
}

void visit(LocalLoadStmt *stmt) override {
Expand All @@ -100,6 +125,8 @@ class AllocaOptimize : public IRVisitor {
last_store_loaded = true;
if (last_atomic)
last_atomic_eliminable = false;
if (!stored_in_current_block)
loaded_before_first_store_in_current_block = true;
}
}
if (!regular)
Expand All @@ -121,20 +148,52 @@ class AllocaOptimize : public IRVisitor {
}
}

AllocaOptimize new_instance_if_stmt() const {
AllocaOptimize new_instance = *this;
// Avoid eliminating the last store and the last AtomicOpStmt.
if (last_store)
new_instance.last_store_loaded = true;
if (last_atomic)
new_instance.last_atomic_eliminable = false;

new_instance.stored_in_current_block = false;
new_instance.loaded_before_first_store_in_current_block = false;
return new_instance;
}

void visit(IfStmt *if_stmt) override {
TI_ASSERT(if_stmt->true_mask == nullptr);
TI_ASSERT(if_stmt->false_mask == nullptr);

// Create two new instances for IfStmt
AllocaOptimize true_branch = *this;
AllocaOptimize false_branch = *this;
AllocaOptimize true_branch = new_instance_if_stmt();
AllocaOptimize false_branch = new_instance_if_stmt();
if (if_stmt->true_statements) {
if_stmt->true_statements->accept(&true_branch);
}
if (if_stmt->false_statements) {
if_stmt->false_statements->accept(&false_branch);
}

if (last_store && !last_store_loaded &&
true_branch.stored_in_current_block &&
!true_branch.loaded_before_first_store_in_current_block &&
false_branch.stored_in_current_block &&
!false_branch.loaded_before_first_store_in_current_block) {
// The last store before the IfStmt is never loaded.
last_store->parent->erase(last_store);
throw IRModified();
}
if (last_atomic && last_atomic_eliminable &&
true_branch.stored_in_current_block &&
!true_branch.loaded_before_first_store_in_current_block &&
false_branch.stored_in_current_block &&
!false_branch.loaded_before_first_store_in_current_block) {
// The last AtomicOpStmt is never used.
last_atomic->parent->erase(last_atomic);
throw IRModified();
}

stored = true_branch.stored || false_branch.stored;
loaded = true_branch.loaded || false_branch.loaded;

Expand All @@ -145,9 +204,12 @@ class AllocaOptimize : public IRVisitor {
TI_ASSERT(true_branch.last_store != nullptr);
last_store_valid = true;
if (last_store == true_branch.last_store) {
last_store_loaded = last_store_loaded ||
true_branch.last_store_loaded ||
false_branch.last_store_loaded;
TI_ASSERT(!true_branch.stored_in_current_block);
TI_ASSERT(!false_branch.stored_in_current_block);
last_store_loaded =
last_store_loaded ||
true_branch.loaded_before_first_store_in_current_block ||
false_branch.loaded_before_first_store_in_current_block;
} else {
last_store = true_branch.last_store;
last_store_loaded =
Expand All @@ -159,11 +221,18 @@ class AllocaOptimize : public IRVisitor {
if (true_branch.last_store == last_store &&
false_branch.last_store == last_store) {
// The last store didn't change.
last_store_loaded = last_store_loaded ||
true_branch.last_store_loaded ||
false_branch.last_store_loaded;
TI_ASSERT(!true_branch.stored_in_current_block);
TI_ASSERT(!false_branch.stored_in_current_block);
last_store_loaded =
last_store_loaded ||
true_branch.loaded_before_first_store_in_current_block ||
false_branch.loaded_before_first_store_in_current_block;
} else {
// The last store changed, so we can't eliminate last_store.
// The last store changed.
bool current_eliminable =
last_store && !last_store_loaded &&
!true_branch.loaded_before_first_store_in_current_block &&
!false_branch.loaded_before_first_store_in_current_block;
bool true_eliminable = true_branch.last_store != last_store &&
true_branch.last_store != nullptr &&
!true_branch.last_store_loaded;
Expand All @@ -176,6 +245,10 @@ class AllocaOptimize : public IRVisitor {
} else if (false_eliminable) {
last_store = false_branch.last_store;
last_store_loaded = false;
} else if (current_eliminable) {
TI_ASSERT(!true_branch.stored_in_current_block ||
!false_branch.stored_in_current_block);
last_store_loaded = false;
} else {
// Neither branch provides a eliminable local store.
last_store = nullptr;
Expand All @@ -187,11 +260,16 @@ class AllocaOptimize : public IRVisitor {
if (true_branch.last_atomic == last_atomic &&
false_branch.last_atomic == last_atomic) {
// The last AtomicOpStmt didn't change.
last_atomic_eliminable = last_atomic_eliminable &&
true_branch.last_atomic_eliminable &&
false_branch.last_atomic_eliminable;
last_atomic_eliminable =
last_atomic_eliminable &&
!true_branch.loaded_before_first_store_in_current_block &&
!false_branch.loaded_before_first_store_in_current_block;
} else {
// The last AtomicOpStmt changed, so we can't eliminate last_atomic.
// The last AtomicOpStmt changed.
bool current_eliminable =
last_atomic && last_atomic_eliminable &&
!true_branch.loaded_before_first_store_in_current_block &&
!false_branch.loaded_before_first_store_in_current_block;
bool true_eliminable = true_branch.last_atomic != last_atomic &&
true_branch.last_atomic != nullptr &&
true_branch.last_atomic_eliminable;
Expand All @@ -204,6 +282,10 @@ class AllocaOptimize : public IRVisitor {
} else if (false_eliminable) {
last_atomic = false_branch.last_atomic;
last_atomic_eliminable = true;
} else if (current_eliminable) {
TI_ASSERT(!true_branch.stored_in_current_block ||
!false_branch.stored_in_current_block);
last_atomic_eliminable = true;
} else {
// Neither branch provides a eliminable AtomicOpStmt.
last_atomic = nullptr;
Expand All @@ -230,7 +312,7 @@ class AllocaOptimize : public IRVisitor {
body->accept(this);
return;
}
AllocaOptimize loop(alloca_stmt);
AllocaOptimize loop(alloca_stmt, used_atomics);
loop.is_inside_loop = inside_loop_may_have_stores;
body->accept(&loop);

Expand Down Expand Up @@ -312,17 +394,11 @@ class AllocaOptimize : public IRVisitor {
throw IRModified();
}
if (last_atomic && last_atomic_eliminable) {
// The last AtomicOpStmt is never loaded.
// The last AtomicOpStmt is never used.
// last_atomic_valid == false means that it's in an IfStmt.
if (irpass::analysis::gather_statements(
block,
[&](Stmt *stmt) { return stmt->have_operand(last_atomic); })
.empty()) {
// The last AtomicOpStmt is never used.
// Eliminate the last AtomicOpStmt.
last_atomic->parent->erase(last_atomic);
throw IRModified();
}
// Eliminate the last AtomicOpStmt.
last_atomic->parent->erase(last_atomic);
throw IRModified();
}
if (!stored && !loaded) {
// Never stored and never loaded.
Expand All @@ -338,11 +414,13 @@ class AllocaOptimize : public IRVisitor {
class AllocaFindAndOptimize : public BasicStmtVisitor {
private:
std::unordered_set<int> visited;
std::unordered_set<AtomicOpStmt *> *used_atomics;

public:
using BasicStmtVisitor::visit;

AllocaFindAndOptimize() : visited() {
AllocaFindAndOptimize(std::unordered_set<AtomicOpStmt *> *used_atomics)
: visited(), used_atomics(used_atomics) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}
Expand All @@ -358,13 +436,14 @@ class AllocaFindAndOptimize : public BasicStmtVisitor {
void visit(AllocaStmt *alloca_stmt) override {
if (is_done(alloca_stmt))
return;
AllocaOptimize optimizer(alloca_stmt);
AllocaOptimize optimizer(alloca_stmt, used_atomics);
optimizer.run();
set_done(alloca_stmt);
}

static void run(IRNode *node) {
AllocaFindAndOptimize find_and_optimizer;
auto used_atomics = irpass::analysis::gather_used_atomics(node);
AllocaFindAndOptimize find_and_optimizer(&used_atomics);
while (true) {
bool modified = false;
try {
Expand Down
26 changes: 26 additions & 0 deletions tests/python/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,29 @@ def func():
val[None] = 10
func()
assert val[None] == 10


@ti.all_archs
def test_advanced_unused_store_elimination_if():
val = ti.var(ti.i32)
ti.root.place(val)

@ti.kernel
def func():
a = 1
if val[None]:
a = 2
if val[None]:
a = 3
else:
a = 4
val[None] = a
else:
val[None] = a


val[None] = 0
func()
assert val[None] == 1
func()
assert val[None] == 3

0 comments on commit 50ad29a

Please sign in to comment.