Skip to content

Commit

Permalink
[async] [opt] Field value killing analysis (#1929)
Browse files Browse the repository at this point in the history
* [async] Field value killing analysis

* element-wise

* [skip ci] enforce code format

* Fix reaching_definition_analysis being too conservative

* Rename insert_state_flow to insert_edge

* Fix broken state flow chain

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
xumingkuan and taichi-gardener authored Oct 9, 2020
1 parent cfed2d9 commit 4fa0d60
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 57 deletions.
16 changes: 16 additions & 0 deletions taichi/analysis/cfg_analysis.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include "taichi/ir/analysis.h"
#include "taichi/ir/control_flow_graph.h"
#include "taichi/program/async_utils.h"

TLANG_NAMESPACE_BEGIN

namespace irpass::analysis {
void get_meta_input_value_states(IRNode *root, TaskMeta *meta) {
auto cfg = analysis::build_cfg(root);
auto snodes = cfg->gather_loaded_snodes();
for (auto &snode : snodes) {
meta->input_states.emplace(snode, AsyncState::Type::value);
}
}
} // namespace irpass::analysis
TLANG_NAMESPACE_END
3 changes: 3 additions & 0 deletions taichi/ir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ enum AliasResult { same, uncertain, different };

class ControlFlowGraph;

struct TaskMeta;

// IR Analysis
namespace irpass::analysis {

Expand All @@ -68,6 +70,7 @@ std::vector<Stmt *> gather_statements(IRNode *root,
std::unique_ptr<std::unordered_set<AtomicOpStmt *>> gather_used_atomics(
IRNode *root);
std::vector<Stmt *> get_load_pointers(Stmt *load_stmt);
void get_meta_input_value_states(IRNode *root, TaskMeta *meta);
Stmt *get_store_data(Stmt *store_stmt);
std::vector<Stmt *> get_store_destination(Stmt *store_stmt);
bool has_store_or_atomic(IRNode *root, const std::vector<Stmt *> &vars);
Expand Down
86 changes: 80 additions & 6 deletions taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,45 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) {
return modified;
}

void CFGNode::gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const {
// Gather the snodes which this CFGNode loads.
// Requires reaching definition analysis.
std::unordered_set<Stmt *> killed_in_this_node;
for (int i = begin_location; i < end_location; i++) {
auto stmt = block->statements[i].get();
auto load_ptrs = irpass::analysis::get_load_pointers(stmt);
for (auto &load_ptr : load_ptrs) {
if (auto global_ptr = load_ptr->cast<GlobalPtrStmt>()) {
// Avoid computing the UD-chain if every snode in this global ptr
// are already loaded.
bool already_loaded = true;
for (auto &snode : global_ptr->snodes.data) {
if (snodes.count(snode) == 0) {
already_loaded = false;
break;
}
}
if (already_loaded) {
continue;
}
if (reach_in.find(global_ptr) != reach_in.end() &&
!contain_variable(killed_in_this_node, global_ptr)) {
// The UD-chain contains the value before this offload.
for (auto &snode : global_ptr->snodes.data) {
snodes.insert(snode);
}
}
}
}
auto store_ptrs = irpass::analysis::get_store_destination(stmt);
for (auto &store_ptr : store_ptrs) {
if (store_ptr->is<GlobalPtrStmt>()) {
killed_in_this_node.insert(store_ptr);
}
}
}
}

void CFGNode::live_variable_analysis(bool after_lower_access) {
live_gen.clear();
live_kill.clear();
Expand Down Expand Up @@ -600,14 +639,19 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) {
now->reach_out = now->reach_gen;
for (auto stmt : now->reach_in) {
auto store_ptrs = irpass::analysis::get_store_destination(stmt);
bool not_killed = store_ptrs.empty(); // for the case of a global pointer
for (auto store_ptr : store_ptrs) {
if (!now->reach_kill_variable(store_ptr)) {
not_killed = true;
break;
bool killed;
if (store_ptrs.empty()) { // the case of a global pointer
killed = now->reach_kill_variable(stmt);
} else {
killed = true;
for (auto store_ptr : store_ptrs) {
if (!now->reach_kill_variable(store_ptr)) {
killed = false;
break;
}
}
}
if (not_killed) {
if (!killed) {
now->reach_out.insert(stmt);
}
}
Expand Down Expand Up @@ -786,4 +830,34 @@ bool ControlFlowGraph::dead_store_elimination(
return modified;
}

std::unordered_set<SNode *> ControlFlowGraph::gather_loaded_snodes() {
TI_AUTO_PROF;
reaching_definition_analysis(/*after_lower_access=*/false);
const int num_nodes = size();
std::unordered_set<SNode *> snodes;

// Note: since global store may only partially modify a value state, the
// result (which contains the modified and unmodified part) actually needs a
// read from the previous version of the value state.
//
// I.e.,
// output_value_state = merge(input_value_state, written_part)
//
// Therefore we include the nodes[final_node]->reach_in in snodes.
for (auto &stmt : nodes[final_node]->reach_in) {
if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
for (auto &snode : global_ptr->snodes.data) {
snodes.insert(snode);
}
}
}

for (int i = 0; i < num_nodes; i++) {
if (i != final_node) {
nodes[i]->gather_loaded_snodes(snodes);
}
}
return snodes;
}

TLANG_NAMESPACE_END
4 changes: 4 additions & 0 deletions taichi/ir/control_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class CFGNode {
bool reach_kill_variable(Stmt *var) const;
Stmt *get_store_forwarding_data(Stmt *var, int position) const;
bool store_to_load_forwarding(bool after_lower_access);
void gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const;

void live_variable_analysis(bool after_lower_access);
bool dead_store_elimination(bool after_lower_access);
Expand Down Expand Up @@ -110,6 +111,9 @@ class ControlFlowGraph {
bool dead_store_elimination(
bool after_lower_access,
const std::optional<LiveVarAnalysisConfig> &lva_config_opt);

// Gather the SNodes this offload reads.
std::unordered_set<SNode *> gather_loaded_snodes();
};

TLANG_NAMESPACE_END
39 changes: 13 additions & 26 deletions taichi/program/async_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,43 +102,18 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
meta.name = t.kernel->name + "_" +
OffloadedStmt::task_type_name(root_stmt->task_type);
meta.type = root_stmt->task_type;
get_meta_input_value_states(root_stmt, &meta);
gather_statements(root_stmt, [&](Stmt *stmt) {
if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
if (auto ptr = global_load->ptr->cast<GlobalPtrStmt>()) {
for (auto &snode : ptr->snodes.data) {
meta.input_states.emplace(snode, AsyncState::Type::value);
}
}
}

// Note: since global store may only partially modify a value state, the
// result (which contains the modified and unmodified part) actually needs a
// read from the previous version of the value state.
//
// I.e.,
// output_value_state = merge(input_value_state, written_part)
//
// Therefore we include the value state in input_states.
//
// The only exception is that the task may completely overwrite the value
// state (e.g., for i in x: x[i] = 0). However, for now we are not yet
// able to detect that case, so we are being conservative here.

if (auto global_store = stmt->cast<GlobalStoreStmt>()) {
if (auto ptr = global_store->ptr->cast<GlobalPtrStmt>()) {
for (auto &snode : ptr->snodes.data) {
if (!snode->is_scalar()) {
// TODO: This is ad-hoc, use value killing analysis
meta.input_states.emplace(snode, AsyncState::Type::value);
}
meta.output_states.emplace(snode, AsyncState::Type::value);
}
}
}
if (auto global_atomic = stmt->cast<AtomicOpStmt>()) {
if (auto ptr = global_atomic->dest->cast<GlobalPtrStmt>()) {
for (auto &snode : ptr->snodes.data) {
meta.input_states.emplace(snode, AsyncState::Type::value);
meta.output_states.emplace(snode, AsyncState::Type::value);
}
}
Expand Down Expand Up @@ -173,6 +148,18 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
// TODO: handle SNodeOpStmt etc.
return false;
});

// We are being conservative here: if there are any non-element-wise
// accesses (e.g., a = x[i + 1]), we don't treat it as completely
// overwriting the value state (e.g., for i in x: x[i] = 0).
for (auto &state : meta.output_states) {
if (state.type == AsyncState::Type::value) {
if (meta.element_wise.find(state.snode) == meta.element_wise.end()) {
meta.input_states.insert(state);
}
}
}

if (root_stmt->task_type == OffloadedStmt::listgen) {
TI_ASSERT(root_stmt->snode->parent);
meta.snode = root_stmt->snode;
Expand Down
28 changes: 14 additions & 14 deletions taichi/program/state_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,18 +123,22 @@ void StateFlowGraph::insert_task(const TaskLaunchRecord &rec) {
if (latest_state_owner_.find(input_state) == latest_state_owner_.end()) {
latest_state_owner_[input_state] = initial_node_;
}
insert_state_flow(latest_state_owner_[input_state], node.get(),
input_state);
insert_edge(latest_state_owner_[input_state], node.get(), input_state);
}
for (auto output_state : node->meta->output_states) {
latest_state_owner_[output_state] = node.get();
if (latest_state_readers_.find(output_state) ==
latest_state_readers_.end()) {
latest_state_readers_[output_state].insert(initial_node_);
if (latest_state_readers_[output_state].empty()) {
if (latest_state_owner_.find(output_state) != latest_state_owner_.end()) {
// insert a WAW dependency edge
insert_edge(latest_state_owner_[output_state], node.get(),
output_state);
} else {
latest_state_readers_[output_state].insert(initial_node_);
}
}
latest_state_owner_[output_state] = node.get();
for (auto &d : latest_state_readers_[output_state]) {
// insert a dependency edge
insert_state_flow(d, node.get(), output_state);
// insert a WAR dependency edge
insert_edge(d, node.get(), output_state);
}
latest_state_readers_[output_state].clear();
}
Expand All @@ -146,7 +150,7 @@ void StateFlowGraph::insert_task(const TaskLaunchRecord &rec) {
nodes_.push_back(std::move(node));
}

void StateFlowGraph::insert_state_flow(Node *from, Node *to, AsyncState state) {
void StateFlowGraph::insert_edge(Node *from, Node *to, AsyncState state) {
TI_AUTO_PROF;
TI_ASSERT(from != nullptr);
TI_ASSERT(to != nullptr);
Expand Down Expand Up @@ -951,13 +955,9 @@ bool StateFlowGraph::optimize_dead_store() {
}
bool used = false;
for (auto other : task->output_edges[s]) {
if (task->has_state_flow(s, other) &&
(other->meta->input_states.count(s) > 0)) {
if (task->has_state_flow(s, other)) {
// Check if this is a RAW dependency. For scalar SNodes, a WAW flow
// edge decades to a dependency edge.
//
// TODO: This is a hack that only works for scalar SNodes. The proper
// handling would require value killing analysis.
used = true;
} else {
// Note that a dependency edge does not count as an data usage
Expand Down
15 changes: 5 additions & 10 deletions taichi/program/state_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,15 @@ class StateFlowGraph {

// Note:
// Read-after-write leads to flow edges
// Write-after-write leads to flow edges
// Write-after-write leads to dependency edges
// Write-after-read leads to dependency edges
//
// So an edge is a data flow edge iff the starting node writes to the
// So an edge is a data flow edge iff the destination node reads the
// state.
//

if (is_initial_node) {
// The initial node is special.
return destination->meta->input_states.find(state) !=
destination->meta->input_states.end();
} else {
return meta->output_states.find(state) != meta->output_states.end();
}
return destination->meta->input_states.find(state) !=
destination->meta->input_states.end();
}

void disconnect_all();
Expand Down Expand Up @@ -127,7 +122,7 @@ class StateFlowGraph {

void insert_task(const TaskLaunchRecord &rec);

void insert_state_flow(Node *from, Node *to, AsyncState state);
void insert_edge(Node *from, Node *to, AsyncState state);

// Compute transitive closure for tasks in get_pending_tasks()[begin, end).
std::pair<std::vector<bit::Bitset>, std::vector<bit::Bitset>>
Expand Down
1 change: 0 additions & 1 deletion tests/python/test_sfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def serial_z():

@ti.test(require=ti.extension.async_mode, async_mode=True)
def test_sfg_dead_store_elimination():
ti.init(arch=ti.cpu, async_mode=True)
n = 32

x = ti.field(dtype=float, shape=n, needs_grad=True)
Expand Down

0 comments on commit 4fa0d60

Please sign in to comment.