From 78e8c3f23ff0d059245a3b53b3b33e6c1c0c7797 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Thu, 24 Sep 2020 17:25:15 +0800 Subject: [PATCH 1/4] [async] Cache demote_activation --- taichi/program/async_engine.cpp | 90 ++++++++++++++++++++++++++++ taichi/program/async_engine.h | 3 + taichi/program/async_utils.h | 4 ++ taichi/program/state_flow_graph.cpp | 91 +++-------------------------- 4 files changed, 105 insertions(+), 83 deletions(-) diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index 15d2c22d3f648..71dfed3bbab98 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -106,6 +106,96 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) { return result; } +// TODO: make this an IR pass +class ConstExprPropagation { + public: + static std::unordered_set run( + Block *block, + const std::function &is_const_seed) { + std::unordered_set const_stmts; + + auto is_const = [&](Stmt *stmt) { + if (is_const_seed(stmt)) { + return true; + } else { + return const_stmts.find(stmt) != const_stmts.end(); + } + }; + + for (auto &s : block->statements) { + if (is_const(s.get())) { + const_stmts.insert(s.get()); + } else if (auto binary = s->cast()) { + if (is_const(binary->lhs) && is_const(binary->rhs)) { + const_stmts.insert(s.get()); + } + } else if (auto unary = s->cast()) { + if (is_const(unary->operand)) { + const_stmts.insert(s.get()); + } + } else { + // TODO: ... + } + } + + return const_stmts; + } +}; + +IRHandle IRBank::demote_activation(IRHandle handle) { + auto &result = demote_activation_bank_[handle]; + if (!result.empty()) { + return result; + } + + std::unique_ptr new_ir = handle.clone(); + + OffloadedStmt *offload = new_ir->as(); + Block *body = offload->body.get(); + + auto snode = offload->snode; + TI_ASSERT(snode != nullptr); + + // TODO: for now we only deal with the top level. Is there an easy way to + // extend this part? + auto consts = ConstExprPropagation::run(body, [](Stmt *stmt) { + if (stmt->is()) { + return true; + } else if (stmt->is()) + return true; + return false; + }); + + bool demoted = false; + for (int k = 0; k < (int)body->statements.size(); k++) { + Stmt *stmt = body->statements[k].get(); + if (auto ptr = stmt->cast(); ptr && ptr->activate) { + bool can_demote = true; + // TODO: test input mask? + for (auto ind : ptr->indices) { + if (consts.find(ind) == consts.end()) { + // non-constant index + can_demote = false; + } + } + if (can_demote) { + ptr->activate = false; + demoted = true; + } + } + } + + if (!demoted) { + // Nothing demoted. Simply delete new_ir when this function returns. + result = handle; + return result; + } + + result = IRHandle(new_ir.get(), get_hash(new_ir.get())); + insert(std::move(new_ir), result.hash()); + return result; +} + ParallelExecutor::ParallelExecutor(int num_threads) : num_threads(num_threads), status(ExecutorStatus::uninitialized), diff --git a/taichi/program/async_engine.h b/taichi/program/async_engine.h index 236ea23e7dec5..9de8d84a5130e 100644 --- a/taichi/program/async_engine.h +++ b/taichi/program/async_engine.h @@ -31,6 +31,8 @@ class IRBank { // Fuse handle_b into handle_a IRHandle fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel); + IRHandle demote_activation(IRHandle handle); + std::unordered_map meta_bank_; std::unordered_map fusion_meta_bank_; @@ -39,6 +41,7 @@ class IRBank { std::unordered_map> ir_bank_; std::vector> trash_bin; // prevent IR from deleted std::unordered_map, IRHandle> fuse_bank_; + std::unordered_map demote_activation_bank_; }; class ParallelExecutor { diff --git a/taichi/program/async_utils.h b/taichi/program/async_utils.h index 2fcd4ec8b6f0a..96ec7517e7eda 100644 --- a/taichi/program/async_utils.h +++ b/taichi/program/async_utils.h @@ -41,6 +41,10 @@ class IRHandle { return hash_ == other_ir_handle.hash_; } + bool operator!=(const IRHandle &other_ir_handle) const { + return !(*this == other_ir_handle); + } + bool operator<(const IRHandle &other_ir_handle) const { return hash_ < other_ir_handle.hash_; } diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index 83fe1d1081e2b..be062a3c1433d 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -917,42 +917,6 @@ void StateFlowGraph::verify() { topo_sort_nodes(); } -// TODO: make this an IR pass -class ConstExprPropagation { - public: - static std::unordered_set run( - Block *block, - const std::function &is_const_seed) { - std::unordered_set const_stmts; - - auto is_const = [&](Stmt *stmt) { - if (is_const_seed(stmt)) { - return true; - } else { - return const_stmts.find(stmt) != const_stmts.end(); - } - }; - - for (auto &s : block->statements) { - if (is_const(s.get())) { - const_stmts.insert(s.get()); - } else if (auto binary = s->cast()) { - if (is_const(binary->lhs) && is_const(binary->rhs)) { - const_stmts.insert(s.get()); - } - } else if (auto unary = s->cast()) { - if (is_const(unary->operand)) { - const_stmts.insert(s.get()); - } - } else { - // TODO: ... - } - } - - return const_stmts; - } -}; - bool StateFlowGraph::demote_activation() { bool modified = false; @@ -983,55 +947,16 @@ bool StateFlowGraph::demote_activation() { if (nodes.size() <= 1) continue; - auto snode = nodes[0]->meta->snode; - - auto list_state = AsyncState(snode, AsyncState::Type::list); - - TI_ASSERT(snode != nullptr); - - std::unique_ptr new_ir = nodes[0]->rec.ir_handle.clone(); - - OffloadedStmt *offload = new_ir->as(); - Block *body = offload->body.get(); - - // TODO: for now we only deal with the top level. Is there an easy way to - // extend this part? - auto consts = ConstExprPropagation::run(body, [](Stmt *stmt) { - if (stmt->is()) { - return true; - } else if (stmt->is()) - return true; - return false; - }); - - bool demoted = false; - for (int k = 0; k < (int)body->statements.size(); k++) { - Stmt *stmt = body->statements[k].get(); - if (auto ptr = stmt->cast(); ptr && ptr->activate) { - bool can_demote = true; - // TODO: test input mask? - for (auto ind : ptr->indices) { - if (consts.find(ind) == consts.end()) { - // non-constant index - can_demote = false; - } - } - if (can_demote) { - modified = true; - ptr->activate = false; - demoted = true; - } - } - } - // TODO: cache this part - auto new_handle = IRHandle(new_ir.get(), ir_bank_->get_hash(new_ir.get())); - ir_bank_->insert(std::move(new_ir), new_handle.hash()); - auto new_meta = get_task_meta(ir_bank_, nodes[0]->rec); - if (demoted) { - for (int j = 1; j < (int)nodes.size(); j++) { + auto new_handle = ir_bank_->demote_activation(nodes[0]->rec.ir_handle); + if (new_handle != nodes[0]->rec.ir_handle) { + modified = true; + nodes[1]->rec.ir_handle = new_handle; + nodes[1]->meta = get_task_meta(ir_bank_, nodes[1]->rec); + for (int j = 2; j < (int)nodes.size(); j++) { nodes[j]->rec.ir_handle = new_handle; - nodes[j]->meta = new_meta; + nodes[j]->meta = nodes[1]->meta; } + // TODO: do we need to break here? break; } } From a1182a341d5c7c8ac7d70218e9891e370ae44339 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Fri, 25 Sep 2020 11:38:17 +0800 Subject: [PATCH 2/4] [skip ci] Apply suggestions from code review Co-authored-by: Yuanming Hu --- taichi/program/state_flow_graph.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index be062a3c1433d..fa7471a09533f 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -956,7 +956,9 @@ bool StateFlowGraph::demote_activation() { nodes[j]->rec.ir_handle = new_handle; nodes[j]->meta = nodes[1]->meta; } - // TODO: do we need to break here? + // For every "demote_activation" call, we only optimize for + // a single key in std::map, std::vector> tasks + // since the graph probably needs to be rebuild after demoting part of the tasks. break; } } From 5ec64359b73ef47651e2a3e54dc61c392c0874bc Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Thu, 24 Sep 2020 23:38:38 -0400 Subject: [PATCH 3/4] [skip ci] enforce code format --- taichi/program/state_flow_graph.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index fa7471a09533f..da015d1020735 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -956,9 +956,10 @@ bool StateFlowGraph::demote_activation() { nodes[j]->rec.ir_handle = new_handle; nodes[j]->meta = nodes[1]->meta; } - // For every "demote_activation" call, we only optimize for - // a single key in std::map, std::vector> tasks - // since the graph probably needs to be rebuild after demoting part of the tasks. + // For every "demote_activation" call, we only optimize for + // a single key in std::map, std::vector> tasks since the graph probably needs to be rebuild after demoting + // part of the tasks. break; } } From 5951fad3aca18edd8ec7f32aa14ed749d80bead0 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Fri, 25 Sep 2020 11:41:09 +0800 Subject: [PATCH 4/4] [skip ci] Edit comment --- taichi/program/state_flow_graph.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index da015d1020735..f3a98f59cf1a6 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -956,9 +956,9 @@ bool StateFlowGraph::demote_activation() { nodes[j]->rec.ir_handle = new_handle; nodes[j]->meta = nodes[1]->meta; } - // For every "demote_activation" call, we only optimize for - // a single key in std::map, std::vector> tasks since the graph probably needs to be rebuild after demoting + // For every "demote_activation" call, we only optimize for a single key + // in std::map, std::vector> tasks + // since the graph probably needs to be rebuild after demoting // part of the tasks. break; }