diff --git a/taichi/analysis/cfg_analysis.cpp b/taichi/analysis/cfg_analysis.cpp
new file mode 100644
index 0000000000000..b9c23cb078257
--- /dev/null
+++ b/taichi/analysis/cfg_analysis.cpp
@@ -0,0 +1,16 @@
+#include "taichi/ir/analysis.h"
+#include "taichi/ir/control_flow_graph.h"
+#include "taichi/program/async_utils.h"
+
+TLANG_NAMESPACE_BEGIN
+
+namespace irpass::analysis {
+void get_meta_input_value_states(IRNode *root, TaskMeta *meta) {
+  auto cfg = analysis::build_cfg(root);
+  auto snodes = cfg->gather_loaded_snodes();
+  for (auto &snode : snodes) {
+    meta->input_states.emplace(snode, AsyncState::Type::value);
+  }
+}
+}  // namespace irpass::analysis
+TLANG_NAMESPACE_END
diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h
index cd05f5a528596..d48635e7f343f 100644
--- a/taichi/ir/analysis.h
+++ b/taichi/ir/analysis.h
@@ -51,6 +51,8 @@ enum AliasResult { same, uncertain, different };
 
 class ControlFlowGraph;
 
+struct TaskMeta;
+
 // IR Analysis
 namespace irpass::analysis {
 
@@ -68,6 +70,7 @@ std::vector<Stmt *> gather_statements(IRNode *root,
 std::unique_ptr<std::unordered_set<AtomicOpStmt *>> gather_used_atomics(
     IRNode *root);
 std::vector<Stmt *> get_load_pointers(Stmt *load_stmt);
+void get_meta_input_value_states(IRNode *root, TaskMeta *meta);
 Stmt *get_store_data(Stmt *store_stmt);
 std::vector<Stmt *> get_store_destination(Stmt *store_stmt);
 bool has_store_or_atomic(IRNode *root, const std::vector<Stmt *> &vars);
diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp
index dd0f3743fabaf..6b8daae92ac5d 100644
--- a/taichi/ir/control_flow_graph.cpp
+++ b/taichi/ir/control_flow_graph.cpp
@@ -321,6 +321,45 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) {
   return modified;
 }
 
+void CFGNode::gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const {
+  // Gather the snodes which this CFGNode loads.
+  // Requires reaching definition analysis.
+  std::unordered_set<Stmt *> killed_in_this_node;
+  for (int i = begin_location; i < end_location; i++) {
+    auto stmt = block->statements[i].get();
+    auto load_ptrs = irpass::analysis::get_load_pointers(stmt);
+    for (auto &load_ptr : load_ptrs) {
+      if (auto global_ptr = load_ptr->cast<GlobalPtrStmt>()) {
+        // Avoid computing the UD-chain if every snode in this global ptr
+        // is already loaded.
+        bool already_loaded = true;
+        for (auto &snode : global_ptr->snodes.data) {
+          if (snodes.count(snode) == 0) {
+            already_loaded = false;
+            break;
+          }
+        }
+        if (already_loaded) {
+          continue;
+        }
+        if (reach_in.find(global_ptr) != reach_in.end() &&
+            !contain_variable(killed_in_this_node, global_ptr)) {
+          // The UD-chain contains the value before this offload.
+          for (auto &snode : global_ptr->snodes.data) {
+            snodes.insert(snode);
+          }
+        }
+      }
+    }
+    auto store_ptrs = irpass::analysis::get_store_destination(stmt);
+    for (auto &store_ptr : store_ptrs) {
+      if (store_ptr->is<GlobalPtrStmt>()) {
+        killed_in_this_node.insert(store_ptr);
+      }
+    }
+  }
+}
+
 void CFGNode::live_variable_analysis(bool after_lower_access) {
   live_gen.clear();
   live_kill.clear();
@@ -600,14 +639,19 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) {
       now->reach_out = now->reach_gen;
       for (auto stmt : now->reach_in) {
         auto store_ptrs = irpass::analysis::get_store_destination(stmt);
-        bool not_killed = store_ptrs.empty();  // for the case of a global pointer
-        for (auto store_ptr : store_ptrs) {
-          if (!now->reach_kill_variable(store_ptr)) {
-            not_killed = true;
-            break;
+        bool killed;
+        if (store_ptrs.empty()) {  // the case of a global pointer
+          killed = now->reach_kill_variable(stmt);
+        } else {
+          killed = true;
+          for (auto store_ptr : store_ptrs) {
+            if (!now->reach_kill_variable(store_ptr)) {
+              killed = false;
+              break;
+            }
           }
         }
-        if (not_killed) {
+        if (!killed) {
           now->reach_out.insert(stmt);
         }
       }
@@ -786,4 +830,34 @@ bool ControlFlowGraph::dead_store_elimination(
   return modified;
 }
 
+std::unordered_set<SNode *> ControlFlowGraph::gather_loaded_snodes() {
+  TI_AUTO_PROF;
+  reaching_definition_analysis(/*after_lower_access=*/false);
+  const int num_nodes = size();
+  std::unordered_set<SNode *> snodes;
+
+  // Note: since a global store may only partially modify a value state, the
+  // result (which contains the modified and unmodified parts) actually needs
+  // a read from the previous version of the value state.
+  //
+  // I.e.,
+  // output_value_state = merge(input_value_state, written_part)
+  //
+  // Therefore we include the SNodes in nodes[final_node]->reach_in in snodes.
+  for (auto &stmt : nodes[final_node]->reach_in) {
+    if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
+      for (auto &snode : global_ptr->snodes.data) {
+        snodes.insert(snode);
+      }
+    }
+  }
+
+  for (int i = 0; i < num_nodes; i++) {
+    if (i != final_node) {
+      nodes[i]->gather_loaded_snodes(snodes);
+    }
+  }
+  return snodes;
+}
+
 TLANG_NAMESPACE_END
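Aside: the per-node logic in CFGNode::gather_loaded_snodes() above boils down to "a load observes the pre-offload value only if a definition from outside the node reaches it (reach_in) and no store inside the node has already killed the pointer". The following minimal, standalone C++ sketch illustrates that interplay; ToyStmt, gather_loaded, and the string-keyed sets are invented stand-ins, not Taichi types (the real pass keys reach_in by statements and resolves aliasing):

// Hypothetical toy types for illustration only, not Taichi's real
// Stmt/GlobalPtrStmt; everything here is keyed by SNode name.
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

struct ToyStmt {
  bool is_load;       // false means a store to the same pointer
  std::string snode;  // which SNode the global pointer touches
};

// Collects the SNodes whose pre-offload value this node observes.
std::unordered_set<std::string> gather_loaded(
    const std::vector<ToyStmt> &stmts,
    const std::unordered_set<std::string> &reach_in) {
  std::unordered_set<std::string> loaded;
  std::unordered_set<std::string> killed_in_this_node;
  for (const auto &s : stmts) {
    if (s.is_load) {
      // A definition from before the node reaches this load, and no store
      // inside the node has overwritten it yet.
      if (reach_in.count(s.snode) > 0 &&
          killed_in_this_node.count(s.snode) == 0) {
        loaded.insert(s.snode);
      }
    } else {
      killed_in_this_node.insert(s.snode);  // later loads see the local store
    }
  }
  return loaded;
}

int main() {
  // x is stored before it is loaded, so only y's old value is observed.
  const std::vector<ToyStmt> stmts = {{false, "x"}, {true, "x"}, {true, "y"}};
  const auto loaded = gather_loaded(stmts, /*reach_in=*/{"x", "y"});
  assert(loaded.count("y") == 1 && loaded.count("x") == 0);
  return 0;
}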
diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h
index d82c00959aa4e..dd04c7d2a5377 100644
--- a/taichi/ir/control_flow_graph.h
+++ b/taichi/ir/control_flow_graph.h
@@ -62,6 +62,7 @@ class CFGNode {
   bool reach_kill_variable(Stmt *var) const;
   Stmt *get_store_forwarding_data(Stmt *var, int position) const;
   bool store_to_load_forwarding(bool after_lower_access);
+  void gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const;
 
   void live_variable_analysis(bool after_lower_access);
   bool dead_store_elimination(bool after_lower_access);
@@ -110,6 +111,9 @@ class ControlFlowGraph {
   bool dead_store_elimination(
       bool after_lower_access,
       const std::optional<LiveVarAnalysisConfig> &lva_config_opt);
+
+  // Gather the SNodes this offload reads.
+  std::unordered_set<SNode *> gather_loaded_snodes();
 };
 
 TLANG_NAMESPACE_END
diff --git a/taichi/program/async_utils.cpp b/taichi/program/async_utils.cpp
index e32d3b8cde2a3..2b59dcbf8c217 100644
--- a/taichi/program/async_utils.cpp
+++ b/taichi/program/async_utils.cpp
@@ -102,35 +102,11 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
   meta.name = t.kernel->name + "_" +
               OffloadedStmt::task_type_name(root_stmt->task_type);
   meta.type = root_stmt->task_type;
+  get_meta_input_value_states(root_stmt, &meta);
   gather_statements(root_stmt, [&](Stmt *stmt) {
-    if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
-      if (auto ptr = global_load->ptr->cast<GlobalPtrStmt>()) {
-        for (auto &snode : ptr->snodes.data) {
-          meta.input_states.emplace(snode, AsyncState::Type::value);
-        }
-      }
-    }
-
-    // Note: since a global store may only partially modify a value state, the
-    // result (which contains the modified and unmodified parts) actually
-    // needs a read from the previous version of the value state.
-    //
-    // I.e.,
-    // output_value_state = merge(input_value_state, written_part)
-    //
-    // Therefore we include the value state in input_states.
-    //
-    // The only exception is that the task may completely overwrite the value
-    // state (e.g., for i in x: x[i] = 0). However, for now we are not yet
-    // able to detect that case, so we are being conservative here.
-
     if (auto global_store = stmt->cast<GlobalStoreStmt>()) {
       if (auto ptr = global_store->ptr->cast<GlobalPtrStmt>()) {
         for (auto &snode : ptr->snodes.data) {
-          if (!snode->is_scalar()) {
-            // TODO: This is ad-hoc, use value killing analysis
-            meta.input_states.emplace(snode, AsyncState::Type::value);
-          }
           meta.output_states.emplace(snode, AsyncState::Type::value);
         }
       }
@@ -138,7 +114,6 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
     if (auto global_atomic = stmt->cast<AtomicOpStmt>()) {
      if (auto ptr = global_atomic->dest->cast<GlobalPtrStmt>()) {
         for (auto &snode : ptr->snodes.data) {
-          meta.input_states.emplace(snode, AsyncState::Type::value);
           meta.output_states.emplace(snode, AsyncState::Type::value);
         }
       }
@@ -173,6 +148,18 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
     // TODO: handle SNodeOpStmt etc.
     return false;
   });
+
+  // We are being conservative here: if there is any non-element-wise access
+  // (e.g., a = x[i + 1]) to an SNode, we do not treat the task as completely
+  // overwriting that SNode's value state (as in: for i in x: x[i] = 0).
+  for (auto &state : meta.output_states) {
+    if (state.type == AsyncState::Type::value) {
+      if (meta.element_wise.find(state.snode) == meta.element_wise.end()) {
+        meta.input_states.insert(state);
+      }
+    }
+  }
+
   if (root_stmt->task_type == OffloadedStmt::listgen) {
     TI_ASSERT(root_stmt->snode->parent);
     meta.snode = root_stmt->snode;
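Aside: the new conservative rule above can be pictured with a small standalone sketch. Meta, add_conservative_inputs, and the string-keyed sets below are invented stand-ins (the real element_wise map is keyed by SNode * and filled during the gather_statements pass); the point is only that a value output state also becomes an input state unless its SNode is accessed purely element-wise:

// Toy illustration of the partial-overwrite rule; names are hypothetical.
#include <cassert>
#include <set>
#include <string>

struct Meta {
  std::set<std::string> input_states, output_states;
  std::set<std::string> element_wise;  // SNodes accessed only element-wise
};

void add_conservative_inputs(Meta &meta) {
  for (const auto &s : meta.output_states) {
    if (meta.element_wise.count(s) == 0) {
      meta.input_states.insert(s);  // partial write: needs the old value
    }
  }
}

int main() {
  Meta meta;
  meta.output_states = {"x", "y"};
  meta.element_wise = {"x"};  // e.g., for i in x: x[i] = 0
  add_conservative_inputs(meta);
  // y (e.g., written as y[i + 1] = ...) keeps its old value as an input
  // state; x, which is fully overwritten element-wise, does not.
  assert(meta.input_states.count("y") == 1 &&
         meta.input_states.count("x") == 0);
  return 0;
}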
diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp
index 5cbaaf42f4d6e..9618c2e174834 100644
--- a/taichi/program/state_flow_graph.cpp
+++ b/taichi/program/state_flow_graph.cpp
@@ -123,18 +123,22 @@ void StateFlowGraph::insert_task(const TaskLaunchRecord &rec) {
     if (latest_state_owner_.find(input_state) == latest_state_owner_.end()) {
       latest_state_owner_[input_state] = initial_node_;
     }
-    insert_state_flow(latest_state_owner_[input_state], node.get(),
-                      input_state);
+    insert_edge(latest_state_owner_[input_state], node.get(), input_state);
   }
   for (auto output_state : node->meta->output_states) {
-    latest_state_owner_[output_state] = node.get();
-    if (latest_state_readers_.find(output_state) ==
-        latest_state_readers_.end()) {
-      latest_state_readers_[output_state].insert(initial_node_);
+    if (latest_state_readers_[output_state].empty()) {
+      if (latest_state_owner_.find(output_state) != latest_state_owner_.end()) {
+        // insert a WAW dependency edge
+        insert_edge(latest_state_owner_[output_state], node.get(),
+                    output_state);
+      } else {
+        latest_state_readers_[output_state].insert(initial_node_);
+      }
     }
+    latest_state_owner_[output_state] = node.get();
     for (auto &d : latest_state_readers_[output_state]) {
-      // insert a dependency edge
-      insert_state_flow(d, node.get(), output_state);
+      // insert a WAR dependency edge
+      insert_edge(d, node.get(), output_state);
     }
     latest_state_readers_[output_state].clear();
   }
@@ -146,7 +150,7 @@ void StateFlowGraph::insert_task(const TaskLaunchRecord &rec) {
   nodes_.push_back(std::move(node));
 }
 
-void StateFlowGraph::insert_state_flow(Node *from, Node *to, AsyncState state) {
+void StateFlowGraph::insert_edge(Node *from, Node *to, AsyncState state) {
   TI_AUTO_PROF;
   TI_ASSERT(from != nullptr);
   TI_ASSERT(to != nullptr);
@@ -951,13 +955,9 @@ bool StateFlowGraph::optimize_dead_store() {
     }
     bool used = false;
     for (auto other : task->output_edges[s]) {
-      if (task->has_state_flow(s, other) &&
-          (other->meta->input_states.count(s) > 0)) {
+      if (task->has_state_flow(s, other)) {
         // Check if this is a RAW dependency. For scalar SNodes, a WAW flow
         // edge decays to a dependency edge.
-        //
-        // TODO: This is a hack that only works for scalar SNodes. The proper
-        // handling would require value killing analysis.
         used = true;
       } else {
         // Note that a dependency edge does not count as a data usage
diff --git a/taichi/program/state_flow_graph.h b/taichi/program/state_flow_graph.h
index dbdca12d98ed3..3a1729f2a358b 100644
--- a/taichi/program/state_flow_graph.h
+++ b/taichi/program/state_flow_graph.h
@@ -78,20 +78,15 @@ class StateFlowGraph {
 
     // Note:
     // Read-after-write leads to flow edges
-    // Write-after-write leads to flow edges
+    // Write-after-write leads to dependency edges
    // Write-after-read leads to dependency edges
     //
-    // So an edge is a data flow edge iff the starting node writes to the
+    // So an edge is a data flow edge iff the destination node reads the
     // state.
     //
-    if (is_initial_node) {
-      // The initial node is special.
-      return destination->meta->input_states.find(state) !=
-             destination->meta->input_states.end();
-    } else {
-      return meta->output_states.find(state) != meta->output_states.end();
-    }
+    return destination->meta->input_states.find(state) !=
+           destination->meta->input_states.end();
   }
 
   void disconnect_all();
 
@@ -127,7 +122,7 @@ class StateFlowGraph {
 
   void insert_task(const TaskLaunchRecord &rec);
 
-  void insert_state_flow(Node *from, Node *to, AsyncState state);
+  void insert_edge(Node *from, Node *to, AsyncState state);
 
   // Compute transitive closure for tasks in get_pending_tasks()[begin, end).
   std::pair<std::vector<Bitset>, std::vector<Bitset>>
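Aside: the rewritten insert_task() bookkeeping follows the classic RAW/WAR/WAW hazard rules. Below is a minimal standalone sketch of the same bookkeeping with invented toy types (string node names instead of Node *, and the initial-node handling omitted). The design point mirrored from the diff: a WAW edge is only needed when no reader has consumed the state since the last write; otherwise the WAR edges already order the two writers.

// Toy state-flow-graph edge insertion; types and names are hypothetical.
#include <cstdio>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using Node = std::string;

struct ToySFG {
  std::unordered_map<std::string, Node> latest_owner;  // last writer per state
  std::unordered_map<std::string, std::unordered_set<Node>>
      readers;  // readers since the last write
  std::vector<std::string> edges;

  void insert_edge(const Node &from, const Node &to, const std::string &state) {
    edges.push_back(from + " -" + state + "-> " + to);
  }

  void read(const Node &n, const std::string &state) {
    if (latest_owner.count(state) > 0) {
      insert_edge(latest_owner[state], n, state);  // RAW flow edge
    }
    readers[state].insert(n);
  }

  void write(const Node &n, const std::string &state) {
    if (readers[state].empty() && latest_owner.count(state) > 0) {
      insert_edge(latest_owner[state], n, state);  // WAW dependency edge
    }
    for (const auto &r : readers[state]) {
      insert_edge(r, n, state);  // WAR dependency edge
    }
    readers[state].clear();
    latest_owner[state] = n;
  }
};

int main() {
  ToySFG g;
  g.write("t1", "x");  // no edge yet
  g.write("t2", "x");  // WAW: t1 -> t2
  g.read("t3", "x");   // RAW: t2 -> t3
  g.write("t4", "x");  // WAR: t3 -> t4 (no WAW edge needed)
  for (const auto &e : g.edges) {
    std::puts(e.c_str());
  }
  return 0;
}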
diff --git a/tests/python/test_sfg.py b/tests/python/test_sfg.py
index 2b93ad881c8a3..1004309a20bbc 100644
--- a/tests/python/test_sfg.py
+++ b/tests/python/test_sfg.py
@@ -64,7 +64,6 @@ def serial_z():
 
 @ti.test(require=ti.extension.async_mode, async_mode=True)
 def test_sfg_dead_store_elimination():
-    ti.init(arch=ti.cpu, async_mode=True)
     n = 32
 
     x = ti.field(dtype=float, shape=n, needs_grad=True)