From ddfa1a644da13c6be576ea21e23d7501f4b88606 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Thu, 8 Oct 2020 17:53:49 +0800 Subject: [PATCH 1/6] [async] Field value killing analysis --- taichi/analysis/cfg_analysis.cpp | 16 +++++++ taichi/ir/analysis.h | 3 ++ taichi/ir/control_flow_graph.cpp | 69 +++++++++++++++++++++++++++++ taichi/ir/control_flow_graph.h | 4 ++ taichi/program/async_utils.cpp | 27 +---------- taichi/program/state_flow_graph.cpp | 17 +++---- taichi/program/state_flow_graph.h | 13 ++---- 7 files changed, 106 insertions(+), 43 deletions(-) create mode 100644 taichi/analysis/cfg_analysis.cpp diff --git a/taichi/analysis/cfg_analysis.cpp b/taichi/analysis/cfg_analysis.cpp new file mode 100644 index 0000000000000..b9c23cb078257 --- /dev/null +++ b/taichi/analysis/cfg_analysis.cpp @@ -0,0 +1,16 @@ +#include "taichi/ir/analysis.h" +#include "taichi/ir/control_flow_graph.h" +#include "taichi/program/async_utils.h" + +TLANG_NAMESPACE_BEGIN + +namespace irpass::analysis { +void get_meta_input_value_states(IRNode *root, TaskMeta *meta) { + auto cfg = analysis::build_cfg(root); + auto snodes = cfg->gather_loaded_snodes(); + for (auto &snode : snodes) { + meta->input_states.emplace(snode, AsyncState::Type::value); + } +} +} // namespace irpass::analysis +TLANG_NAMESPACE_END diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index cd05f5a528596..83c7ce6e99f51 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -51,6 +51,8 @@ enum AliasResult { same, uncertain, different }; class ControlFlowGraph; +class TaskMeta; + // IR Analysis namespace irpass::analysis { @@ -68,6 +70,7 @@ std::vector gather_statements(IRNode *root, std::unique_ptr> gather_used_atomics( IRNode *root); std::vector get_load_pointers(Stmt *load_stmt); +void get_meta_input_value_states(IRNode *root, TaskMeta *meta); Stmt *get_store_data(Stmt *store_stmt); std::vector get_store_destination(Stmt *store_stmt); bool has_store_or_atomic(IRNode *root, const std::vector &vars); diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index dd0f3743fabaf..2328db195719c 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -321,6 +321,45 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) { return modified; } +void CFGNode::gather_loaded_snodes(std::unordered_set &snodes) const { + // Gather the snodes which this CFGNode loads. + // Requires reaching definition analysis. + std::unordered_set killed_in_this_node; + for (int i = begin_location; i < end_location; i++) { + auto stmt = block->statements[i].get(); + auto load_ptrs = irpass::analysis::get_load_pointers(stmt); + for (auto &load_ptr : load_ptrs) { + if (auto global_ptr = load_ptr->cast()) { + // Avoid computing the UD-chain if every snode in this global ptr + // are already loaded. + bool already_loaded = true; + for (auto &snode : global_ptr->snodes.data) { + if (snodes.count(snode) == 0) { + already_loaded = false; + break; + } + } + if (already_loaded) { + continue; + } + if (reach_in.find(global_ptr) != reach_in.end() && + !contain_variable(killed_in_this_node, global_ptr)) { + // The UD-chain contains the value before this offload. + for (auto &snode : global_ptr->snodes.data) { + snodes.insert(snode); + } + } + } + } + auto store_ptrs = irpass::analysis::get_store_destination(stmt); + for (auto &store_ptr : store_ptrs) { + if (store_ptr->is()) { + killed_in_this_node.insert(store_ptr); + } + } + } +} + void CFGNode::live_variable_analysis(bool after_lower_access) { live_gen.clear(); live_kill.clear(); @@ -786,4 +825,34 @@ bool ControlFlowGraph::dead_store_elimination( return modified; } +std::unordered_set ControlFlowGraph::gather_loaded_snodes() { + TI_AUTO_PROF; + reaching_definition_analysis(/*after_lower_access=*/false); + const int num_nodes = size(); + std::unordered_set snodes; + + // Note: since global store may only partially modify a value state, the + // result (which contains the modified and unmodified part) actually needs a + // read from the previous version of the value state. + // + // I.e., + // output_value_state = merge(input_value_state, written_part) + // + // Therefore we include the nodes[final_node]->reach_in in snodes. + for (auto &stmt : nodes[final_node]->reach_in) { + if (auto global_ptr = stmt->cast()) { + for (auto &snode : global_ptr->snodes.data) { + snodes.insert(snode); + } + } + } + + for (int i = 0; i < num_nodes; i++) { + if (i != final_node) { + nodes[i]->gather_loaded_snodes(snodes); + } + } + return snodes; +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index d82c00959aa4e..dd04c7d2a5377 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -62,6 +62,7 @@ class CFGNode { bool reach_kill_variable(Stmt *var) const; Stmt *get_store_forwarding_data(Stmt *var, int position) const; bool store_to_load_forwarding(bool after_lower_access); + void gather_loaded_snodes(std::unordered_set &snodes) const; void live_variable_analysis(bool after_lower_access); bool dead_store_elimination(bool after_lower_access); @@ -110,6 +111,9 @@ class ControlFlowGraph { bool dead_store_elimination( bool after_lower_access, const std::optional &lva_config_opt); + + // Gather the SNodes this offload reads. + std::unordered_set gather_loaded_snodes(); }; TLANG_NAMESPACE_END diff --git a/taichi/program/async_utils.cpp b/taichi/program/async_utils.cpp index e32d3b8cde2a3..8ae0ef26f3817 100644 --- a/taichi/program/async_utils.cpp +++ b/taichi/program/async_utils.cpp @@ -102,35 +102,11 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { meta.name = t.kernel->name + "_" + OffloadedStmt::task_type_name(root_stmt->task_type); meta.type = root_stmt->task_type; + get_meta_input_value_states(root_stmt, &meta); gather_statements(root_stmt, [&](Stmt *stmt) { - if (auto global_load = stmt->cast()) { - if (auto ptr = global_load->ptr->cast()) { - for (auto &snode : ptr->snodes.data) { - meta.input_states.emplace(snode, AsyncState::Type::value); - } - } - } - - // Note: since global store may only partially modify a value state, the - // result (which contains the modified and unmodified part) actually needs a - // read from the previous version of the value state. - // - // I.e., - // output_value_state = merge(input_value_state, written_part) - // - // Therefore we include the value state in input_states. - // - // The only exception is that the task may completely overwrite the value - // state (e.g., for i in x: x[i] = 0). However, for now we are not yet - // able to detect that case, so we are being conservative here. - if (auto global_store = stmt->cast()) { if (auto ptr = global_store->ptr->cast()) { for (auto &snode : ptr->snodes.data) { - if (!snode->is_scalar()) { - // TODO: This is ad-hoc, use value killing analysis - meta.input_states.emplace(snode, AsyncState::Type::value); - } meta.output_states.emplace(snode, AsyncState::Type::value); } } @@ -138,7 +114,6 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { if (auto global_atomic = stmt->cast()) { if (auto ptr = global_atomic->dest->cast()) { for (auto &snode : ptr->snodes.data) { - meta.input_states.emplace(snode, AsyncState::Type::value); meta.output_states.emplace(snode, AsyncState::Type::value); } } diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index 5cbaaf42f4d6e..a1453a3d53e28 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -127,13 +127,18 @@ void StateFlowGraph::insert_task(const TaskLaunchRecord &rec) { input_state); } for (auto output_state : node->meta->output_states) { - latest_state_owner_[output_state] = node.get(); if (latest_state_readers_.find(output_state) == latest_state_readers_.end()) { - latest_state_readers_[output_state].insert(initial_node_); + if (latest_state_owner_.find(output_state) != latest_state_owner_.end()) { + // insert a WAW dependency edge + insert_state_flow(latest_state_owner_[output_state], node.get(), output_state); + } else { + latest_state_readers_[output_state].insert(initial_node_); + } } + latest_state_owner_[output_state] = node.get(); for (auto &d : latest_state_readers_[output_state]) { - // insert a dependency edge + // insert a WAR dependency edge insert_state_flow(d, node.get(), output_state); } latest_state_readers_[output_state].clear(); @@ -951,13 +956,9 @@ bool StateFlowGraph::optimize_dead_store() { } bool used = false; for (auto other : task->output_edges[s]) { - if (task->has_state_flow(s, other) && - (other->meta->input_states.count(s) > 0)) { + if (task->has_state_flow(s, other)) { // Check if this is a RAW dependency. For scalar SNodes, a WAW flow // edge decades to a dependency edge. - // - // TODO: This is a hack that only works for scalar SNodes. The proper - // handling would require value killing analysis. used = true; } else { // Note that a dependency edge does not count as an data usage diff --git a/taichi/program/state_flow_graph.h b/taichi/program/state_flow_graph.h index dbdca12d98ed3..7a9c4f7fadd65 100644 --- a/taichi/program/state_flow_graph.h +++ b/taichi/program/state_flow_graph.h @@ -78,20 +78,15 @@ class StateFlowGraph { // Note: // Read-after-write leads to flow edges - // Write-after-write leads to flow edges + // Write-after-write leads to dependency edges // Write-after-read leads to dependency edges // - // So an edge is a data flow edge iff the starting node writes to the + // So an edge is a data flow edge iff the destination node reads the // state. // - if (is_initial_node) { - // The initial node is special. - return destination->meta->input_states.find(state) != - destination->meta->input_states.end(); - } else { - return meta->output_states.find(state) != meta->output_states.end(); - } + return destination->meta->input_states.find(state) != + destination->meta->input_states.end(); } void disconnect_all(); From a343840963952e8a52c92f3619cbb8bfb84e0a60 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Thu, 8 Oct 2020 18:09:55 +0800 Subject: [PATCH 2/6] element-wise --- taichi/program/async_utils.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/taichi/program/async_utils.cpp b/taichi/program/async_utils.cpp index 8ae0ef26f3817..2b59dcbf8c217 100644 --- a/taichi/program/async_utils.cpp +++ b/taichi/program/async_utils.cpp @@ -148,6 +148,18 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { // TODO: handle SNodeOpStmt etc. return false; }); + + // We are being conservative here: if there are any non-element-wise + // accesses (e.g., a = x[i + 1]), we don't treat it as completely + // overwriting the value state (e.g., for i in x: x[i] = 0). + for (auto &state : meta.output_states) { + if (state.type == AsyncState::Type::value) { + if (meta.element_wise.find(state.snode) == meta.element_wise.end()) { + meta.input_states.insert(state); + } + } + } + if (root_stmt->task_type == OffloadedStmt::listgen) { TI_ASSERT(root_stmt->snode->parent); meta.snode = root_stmt->snode; From 861238d7683c5f5f7328ee59248e6f475ca44309 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Thu, 8 Oct 2020 06:11:00 -0400 Subject: [PATCH 3/6] [skip ci] enforce code format --- taichi/program/state_flow_graph.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index a1453a3d53e28..ffca6a358d222 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -131,7 +131,8 @@ void StateFlowGraph::insert_task(const TaskLaunchRecord &rec) { latest_state_readers_.end()) { if (latest_state_owner_.find(output_state) != latest_state_owner_.end()) { // insert a WAW dependency edge - insert_state_flow(latest_state_owner_[output_state], node.get(), output_state); + insert_state_flow(latest_state_owner_[output_state], node.get(), + output_state); } else { latest_state_readers_[output_state].insert(initial_node_); } From 458847a473fd5dab08b5166f105f34cad059ac88 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Fri, 9 Oct 2020 17:38:40 +0800 Subject: [PATCH 4/6] Fix reaching_definition_analysis being too conservative --- taichi/ir/analysis.h | 2 +- taichi/ir/control_flow_graph.cpp | 17 +++++++++++------ tests/python/test_sfg.py | 1 - 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index 83c7ce6e99f51..d48635e7f343f 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -51,7 +51,7 @@ enum AliasResult { same, uncertain, different }; class ControlFlowGraph; -class TaskMeta; +struct TaskMeta; // IR Analysis namespace irpass::analysis { diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 2328db195719c..6b8daae92ac5d 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -639,14 +639,19 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { now->reach_out = now->reach_gen; for (auto stmt : now->reach_in) { auto store_ptrs = irpass::analysis::get_store_destination(stmt); - bool not_killed = store_ptrs.empty(); // for the case of a global pointer - for (auto store_ptr : store_ptrs) { - if (!now->reach_kill_variable(store_ptr)) { - not_killed = true; - break; + bool killed; + if (store_ptrs.empty()) { // the case of a global pointer + killed = now->reach_kill_variable(stmt); + } else { + killed = true; + for (auto store_ptr : store_ptrs) { + if (!now->reach_kill_variable(store_ptr)) { + killed = false; + break; + } } } - if (not_killed) { + if (!killed) { now->reach_out.insert(stmt); } } diff --git a/tests/python/test_sfg.py b/tests/python/test_sfg.py index 2b93ad881c8a3..1004309a20bbc 100644 --- a/tests/python/test_sfg.py +++ b/tests/python/test_sfg.py @@ -64,7 +64,6 @@ def serial_z(): @ti.test(require=ti.extension.async_mode, async_mode=True) def test_sfg_dead_store_elimination(): - ti.init(arch=ti.cpu, async_mode=True) n = 32 x = ti.field(dtype=float, shape=n, needs_grad=True) From af0c5d7e8e9c723446a5cfee1753dd7a5cfab87c Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Fri, 9 Oct 2020 17:44:10 +0800 Subject: [PATCH 5/6] Rename insert_state_flow to insert_edge --- taichi/program/state_flow_graph.cpp | 11 +++++------ taichi/program/state_flow_graph.h | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index ffca6a358d222..fe8752e773ec8 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -123,16 +123,15 @@ void StateFlowGraph::insert_task(const TaskLaunchRecord &rec) { if (latest_state_owner_.find(input_state) == latest_state_owner_.end()) { latest_state_owner_[input_state] = initial_node_; } - insert_state_flow(latest_state_owner_[input_state], node.get(), - input_state); + insert_edge(latest_state_owner_[input_state], node.get(), input_state); } for (auto output_state : node->meta->output_states) { if (latest_state_readers_.find(output_state) == latest_state_readers_.end()) { if (latest_state_owner_.find(output_state) != latest_state_owner_.end()) { // insert a WAW dependency edge - insert_state_flow(latest_state_owner_[output_state], node.get(), - output_state); + insert_edge(latest_state_owner_[output_state], node.get(), + output_state); } else { latest_state_readers_[output_state].insert(initial_node_); } @@ -140,7 +139,7 @@ void StateFlowGraph::insert_task(const TaskLaunchRecord &rec) { latest_state_owner_[output_state] = node.get(); for (auto &d : latest_state_readers_[output_state]) { // insert a WAR dependency edge - insert_state_flow(d, node.get(), output_state); + insert_edge(d, node.get(), output_state); } latest_state_readers_[output_state].clear(); } @@ -152,7 +151,7 @@ void StateFlowGraph::insert_task(const TaskLaunchRecord &rec) { nodes_.push_back(std::move(node)); } -void StateFlowGraph::insert_state_flow(Node *from, Node *to, AsyncState state) { +void StateFlowGraph::insert_edge(Node *from, Node *to, AsyncState state) { TI_AUTO_PROF; TI_ASSERT(from != nullptr); TI_ASSERT(to != nullptr); diff --git a/taichi/program/state_flow_graph.h b/taichi/program/state_flow_graph.h index 7a9c4f7fadd65..3a1729f2a358b 100644 --- a/taichi/program/state_flow_graph.h +++ b/taichi/program/state_flow_graph.h @@ -122,7 +122,7 @@ class StateFlowGraph { void insert_task(const TaskLaunchRecord &rec); - void insert_state_flow(Node *from, Node *to, AsyncState state); + void insert_edge(Node *from, Node *to, AsyncState state); // Compute transitive closure for tasks in get_pending_tasks()[begin, end). std::pair, std::vector> From 1ca7db59dd0502102a1c7dcee4951a943341fcca Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Fri, 9 Oct 2020 18:10:56 +0800 Subject: [PATCH 6/6] Fix broken state flow chain --- taichi/program/state_flow_graph.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index fe8752e773ec8..9618c2e174834 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -126,8 +126,7 @@ void StateFlowGraph::insert_task(const TaskLaunchRecord &rec) { insert_edge(latest_state_owner_[input_state], node.get(), input_state); } for (auto output_state : node->meta->output_states) { - if (latest_state_readers_.find(output_state) == - latest_state_readers_.end()) { + if (latest_state_readers_[output_state].empty()) { if (latest_state_owner_.find(output_state) != latest_state_owner_.end()) { // insert a WAW dependency edge insert_edge(latest_state_owner_[output_state], node.get(),