Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[async] Cache demote_activation #1889

Merged
merged 4 commits into from
Sep 25, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions taichi/program/async_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,96 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) {
return result;
}

// TODO: make this an IR pass
class ConstExprPropagation {
public:
static std::unordered_set<Stmt *> run(
Block *block,
const std::function<bool(Stmt *)> &is_const_seed) {
std::unordered_set<Stmt *> const_stmts;

auto is_const = [&](Stmt *stmt) {
if (is_const_seed(stmt)) {
return true;
} else {
return const_stmts.find(stmt) != const_stmts.end();
}
};

for (auto &s : block->statements) {
if (is_const(s.get())) {
const_stmts.insert(s.get());
} else if (auto binary = s->cast<BinaryOpStmt>()) {
if (is_const(binary->lhs) && is_const(binary->rhs)) {
const_stmts.insert(s.get());
}
} else if (auto unary = s->cast<UnaryOpStmt>()) {
if (is_const(unary->operand)) {
const_stmts.insert(s.get());
}
} else {
// TODO: ...
}
}

return const_stmts;
}
};

IRHandle IRBank::demote_activation(IRHandle handle) {
auto &result = demote_activation_bank_[handle];
if (!result.empty()) {
return result;
}

std::unique_ptr<IRNode> new_ir = handle.clone();

OffloadedStmt *offload = new_ir->as<OffloadedStmt>();
Block *body = offload->body.get();

auto snode = offload->snode;
TI_ASSERT(snode != nullptr);

// TODO: for now we only deal with the top level. Is there an easy way to
// extend this part?
auto consts = ConstExprPropagation::run(body, [](Stmt *stmt) {
if (stmt->is<ConstStmt>()) {
return true;
} else if (stmt->is<LoopIndexStmt>())
return true;
return false;
});

bool demoted = false;
for (int k = 0; k < (int)body->statements.size(); k++) {
Stmt *stmt = body->statements[k].get();
if (auto ptr = stmt->cast<GlobalPtrStmt>(); ptr && ptr->activate) {
bool can_demote = true;
// TODO: test input mask?
for (auto ind : ptr->indices) {
if (consts.find(ind) == consts.end()) {
// non-constant index
can_demote = false;
}
}
if (can_demote) {
ptr->activate = false;
demoted = true;
}
}
}

if (!demoted) {
// Nothing demoted. Simply delete new_ir when this function returns.
result = handle;
k-ye marked this conversation as resolved.
Show resolved Hide resolved
return result;
}

result = IRHandle(new_ir.get(), get_hash(new_ir.get()));
insert(std::move(new_ir), result.hash());
return result;
}

ParallelExecutor::ParallelExecutor(int num_threads)
: num_threads(num_threads),
status(ExecutorStatus::uninitialized),
Expand Down
3 changes: 3 additions & 0 deletions taichi/program/async_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class IRBank {
// Fuse handle_b into handle_a
IRHandle fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel);

IRHandle demote_activation(IRHandle handle);

std::unordered_map<IRHandle, TaskMeta> meta_bank_;
std::unordered_map<IRHandle, TaskFusionMeta> fusion_meta_bank_;

Expand All @@ -39,6 +41,7 @@ class IRBank {
std::unordered_map<IRHandle, std::unique_ptr<IRNode>> ir_bank_;
std::vector<std::unique_ptr<IRNode>> trash_bin; // prevent IR from deleted
std::unordered_map<std::pair<IRHandle, IRHandle>, IRHandle> fuse_bank_;
std::unordered_map<IRHandle, IRHandle> demote_activation_bank_;
};

class ParallelExecutor {
Expand Down
4 changes: 4 additions & 0 deletions taichi/program/async_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ class IRHandle {
return hash_ == other_ir_handle.hash_;
}

bool operator!=(const IRHandle &other_ir_handle) const {
return !(*this == other_ir_handle);
}

bool operator<(const IRHandle &other_ir_handle) const {
return hash_ < other_ir_handle.hash_;
}
Expand Down
94 changes: 11 additions & 83 deletions taichi/program/state_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -917,42 +917,6 @@ void StateFlowGraph::verify() {
topo_sort_nodes();
}

// TODO: make this an IR pass
class ConstExprPropagation {
public:
static std::unordered_set<Stmt *> run(
Block *block,
const std::function<bool(Stmt *)> &is_const_seed) {
std::unordered_set<Stmt *> const_stmts;

auto is_const = [&](Stmt *stmt) {
if (is_const_seed(stmt)) {
return true;
} else {
return const_stmts.find(stmt) != const_stmts.end();
}
};

for (auto &s : block->statements) {
if (is_const(s.get())) {
const_stmts.insert(s.get());
} else if (auto binary = s->cast<BinaryOpStmt>()) {
if (is_const(binary->lhs) && is_const(binary->rhs)) {
const_stmts.insert(s.get());
}
} else if (auto unary = s->cast<UnaryOpStmt>()) {
if (is_const(unary->operand)) {
const_stmts.insert(s.get());
}
} else {
// TODO: ...
}
}

return const_stmts;
}
};

bool StateFlowGraph::demote_activation() {
bool modified = false;

Expand Down Expand Up @@ -983,55 +947,19 @@ bool StateFlowGraph::demote_activation() {
if (nodes.size() <= 1)
continue;

auto snode = nodes[0]->meta->snode;

auto list_state = AsyncState(snode, AsyncState::Type::list);

TI_ASSERT(snode != nullptr);

std::unique_ptr<IRNode> new_ir = nodes[0]->rec.ir_handle.clone();

OffloadedStmt *offload = new_ir->as<OffloadedStmt>();
Block *body = offload->body.get();

// TODO: for now we only deal with the top level. Is there an easy way to
// extend this part?
auto consts = ConstExprPropagation::run(body, [](Stmt *stmt) {
if (stmt->is<ConstStmt>()) {
return true;
} else if (stmt->is<LoopIndexStmt>())
return true;
return false;
});

bool demoted = false;
for (int k = 0; k < (int)body->statements.size(); k++) {
Stmt *stmt = body->statements[k].get();
if (auto ptr = stmt->cast<GlobalPtrStmt>(); ptr && ptr->activate) {
bool can_demote = true;
// TODO: test input mask?
for (auto ind : ptr->indices) {
if (consts.find(ind) == consts.end()) {
// non-constant index
can_demote = false;
}
}
if (can_demote) {
modified = true;
ptr->activate = false;
demoted = true;
}
}
}
// TODO: cache this part
auto new_handle = IRHandle(new_ir.get(), ir_bank_->get_hash(new_ir.get()));
ir_bank_->insert(std::move(new_ir), new_handle.hash());
auto new_meta = get_task_meta(ir_bank_, nodes[0]->rec);
if (demoted) {
for (int j = 1; j < (int)nodes.size(); j++) {
auto new_handle = ir_bank_->demote_activation(nodes[0]->rec.ir_handle);
if (new_handle != nodes[0]->rec.ir_handle) {
modified = true;
nodes[1]->rec.ir_handle = new_handle;
nodes[1]->meta = get_task_meta(ir_bank_, nodes[1]->rec);
for (int j = 2; j < (int)nodes.size(); j++) {
nodes[j]->rec.ir_handle = new_handle;
nodes[j]->meta = new_meta;
nodes[j]->meta = nodes[1]->meta;
yuanming-hu marked this conversation as resolved.
Show resolved Hide resolved
}
// For every "demote_activation" call, we only optimize for
// a single key in std::map<std::pair<IRHandle, Node *>, std::vector<Node
// *>> tasks since the graph probably needs to be rebuild after demoting
// part of the tasks.
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved
break;
}
}
Expand Down