Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[async] [opt] Field value killing analysis #1929

Merged
merged 6 commits into from
Oct 9, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions taichi/analysis/cfg_analysis.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include "taichi/ir/analysis.h"
#include "taichi/ir/control_flow_graph.h"
#include "taichi/program/async_utils.h"

TLANG_NAMESPACE_BEGIN

namespace irpass::analysis {
// Fills in the value-type input states of a task.
//
// Builds a control-flow graph for `root`, asks it which SNodes are loaded
// (ControlFlowGraph::gather_loaded_snodes), and records each of them in
// `meta->input_states` as an AsyncState::Type::value state.
void get_meta_input_value_states(IRNode *root, TaskMeta *meta) {
  auto control_flow_graph = analysis::build_cfg(root);
  const auto loaded_snodes = control_flow_graph->gather_loaded_snodes();
  for (const auto &loaded_snode : loaded_snodes) {
    meta->input_states.emplace(loaded_snode, AsyncState::Type::value);
  }
}
}  // namespace irpass::analysis
TLANG_NAMESPACE_END
3 changes: 3 additions & 0 deletions taichi/ir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ enum AliasResult { same, uncertain, different };

class ControlFlowGraph;

class TaskMeta;

// IR Analysis
namespace irpass::analysis {

Expand All @@ -68,6 +70,7 @@ std::vector<Stmt *> gather_statements(IRNode *root,
std::unique_ptr<std::unordered_set<AtomicOpStmt *>> gather_used_atomics(
IRNode *root);
std::vector<Stmt *> get_load_pointers(Stmt *load_stmt);
void get_meta_input_value_states(IRNode *root, TaskMeta *meta);
Stmt *get_store_data(Stmt *store_stmt);
std::vector<Stmt *> get_store_destination(Stmt *store_stmt);
bool has_store_or_atomic(IRNode *root, const std::vector<Stmt *> &vars);
Expand Down
69 changes: 69 additions & 0 deletions taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,45 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) {
return modified;
}

void CFGNode::gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const {
  // Adds to `snodes` the SNodes that this CFGNode loads.
  // Requires reaching definition analysis (reads `reach_in`).
  //
  // Global pointers stored to by earlier statements of this node are tracked
  // so that later loads through them are not counted.
  std::unordered_set<Stmt *> stored_in_this_node;
  for (int pos = begin_location; pos < end_location; pos++) {
    auto current_stmt = block->statements[pos].get();

    // Loads of a statement are handled before its own stores.
    for (auto &loaded_ptr : irpass::analysis::get_load_pointers(current_stmt)) {
      auto global_ptr = loaded_ptr->cast<GlobalPtrStmt>();
      if (!global_ptr) {
        continue;
      }
      // Skip the UD-chain check when every SNode of this global pointer has
      // already been recorded as loaded.
      bool has_unrecorded_snode = false;
      for (auto &snode : global_ptr->snodes.data) {
        if (snodes.find(snode) == snodes.end()) {
          has_unrecorded_snode = true;
          break;
        }
      }
      if (!has_unrecorded_snode) {
        continue;
      }
      if (reach_in.find(global_ptr) == reach_in.end()) {
        continue;
      }
      if (contain_variable(stored_in_this_node, global_ptr)) {
        continue;
      }
      // The UD-chain contains the value from before this offload, so all
      // SNodes of this pointer count as loaded.
      for (auto &snode : global_ptr->snodes.data) {
        snodes.insert(snode);
      }
    }

    // Remember the global pointers this statement stores to.
    for (auto &stored_ptr :
         irpass::analysis::get_store_destination(current_stmt)) {
      if (stored_ptr->is<GlobalPtrStmt>()) {
        stored_in_this_node.insert(stored_ptr);
      }
    }
  }
}

void CFGNode::live_variable_analysis(bool after_lower_access) {
live_gen.clear();
live_kill.clear();
Expand Down Expand Up @@ -786,4 +825,34 @@ bool ControlFlowGraph::dead_store_elimination(
return modified;
}

std::unordered_set<SNode *> ControlFlowGraph::gather_loaded_snodes() {
TI_AUTO_PROF;
reaching_definition_analysis(/*after_lower_access=*/false);
const int num_nodes = size();
std::unordered_set<SNode *> snodes;

// Note: since global store may only partially modify a value state, the
// result (which contains the modified and unmodified part) actually needs a
// read from the previous version of the value state.
//
// I.e.,
// output_value_state = merge(input_value_state, written_part)
//
// Therefore we include the nodes[final_node]->reach_in in snodes.
for (auto &stmt : nodes[final_node]->reach_in) {
if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
for (auto &snode : global_ptr->snodes.data) {
snodes.insert(snode);
}
}
}

for (int i = 0; i < num_nodes; i++) {
if (i != final_node) {
nodes[i]->gather_loaded_snodes(snodes);
}
}
return snodes;
}

TLANG_NAMESPACE_END
4 changes: 4 additions & 0 deletions taichi/ir/control_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class CFGNode {
bool reach_kill_variable(Stmt *var) const;
Stmt *get_store_forwarding_data(Stmt *var, int position) const;
bool store_to_load_forwarding(bool after_lower_access);
void gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const;

void live_variable_analysis(bool after_lower_access);
bool dead_store_elimination(bool after_lower_access);
Expand Down Expand Up @@ -110,6 +111,9 @@ class ControlFlowGraph {
bool dead_store_elimination(
bool after_lower_access,
const std::optional<LiveVarAnalysisConfig> &lva_config_opt);

// Gather the SNodes this offload reads.
std::unordered_set<SNode *> gather_loaded_snodes();
};

TLANG_NAMESPACE_END
39 changes: 13 additions & 26 deletions taichi/program/async_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,43 +102,18 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
meta.name = t.kernel->name + "_" +
OffloadedStmt::task_type_name(root_stmt->task_type);
meta.type = root_stmt->task_type;
get_meta_input_value_states(root_stmt, &meta);
gather_statements(root_stmt, [&](Stmt *stmt) {
if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
if (auto ptr = global_load->ptr->cast<GlobalPtrStmt>()) {
for (auto &snode : ptr->snodes.data) {
meta.input_states.emplace(snode, AsyncState::Type::value);
}
}
}

// Note: since global store may only partially modify a value state, the
// result (which contains the modified and unmodified part) actually needs a
// read from the previous version of the value state.
//
// I.e.,
// output_value_state = merge(input_value_state, written_part)
//
// Therefore we include the value state in input_states.
//
// The only exception is that the task may completely overwrite the value
// state (e.g., for i in x: x[i] = 0). However, for now we are not yet
// able to detect that case, so we are being conservative here.

if (auto global_store = stmt->cast<GlobalStoreStmt>()) {
if (auto ptr = global_store->ptr->cast<GlobalPtrStmt>()) {
for (auto &snode : ptr->snodes.data) {
if (!snode->is_scalar()) {
// TODO: This is ad-hoc, use value killing analysis
meta.input_states.emplace(snode, AsyncState::Type::value);
}
meta.output_states.emplace(snode, AsyncState::Type::value);
}
}
}
if (auto global_atomic = stmt->cast<AtomicOpStmt>()) {
if (auto ptr = global_atomic->dest->cast<GlobalPtrStmt>()) {
for (auto &snode : ptr->snodes.data) {
meta.input_states.emplace(snode, AsyncState::Type::value);
meta.output_states.emplace(snode, AsyncState::Type::value);
}
}
Expand Down Expand Up @@ -173,6 +148,18 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
// TODO: handle SNodeOpStmt etc.
return false;
});

// We are being conservative here: if there are any non-element-wise
// accesses (e.g., a = x[i + 1]), we don't treat it as completely
// overwriting the value state (e.g., for i in x: x[i] = 0).
for (auto &state : meta.output_states) {
if (state.type == AsyncState::Type::value) {
if (meta.element_wise.find(state.snode) == meta.element_wise.end()) {
meta.input_states.insert(state);
}
}
}

if (root_stmt->task_type == OffloadedStmt::listgen) {
TI_ASSERT(root_stmt->snode->parent);
meta.snode = root_stmt->snode;
Expand Down
18 changes: 10 additions & 8 deletions taichi/program/state_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,19 @@ void StateFlowGraph::insert_task(const TaskLaunchRecord &rec) {
input_state);
}
for (auto output_state : node->meta->output_states) {
latest_state_owner_[output_state] = node.get();
if (latest_state_readers_.find(output_state) ==
latest_state_readers_.end()) {
latest_state_readers_[output_state].insert(initial_node_);
if (latest_state_owner_.find(output_state) != latest_state_owner_.end()) {
// insert a WAW dependency edge
insert_state_flow(latest_state_owner_[output_state], node.get(),
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved
output_state);
} else {
latest_state_readers_[output_state].insert(initial_node_);
}
}
latest_state_owner_[output_state] = node.get();
for (auto &d : latest_state_readers_[output_state]) {
// insert a dependency edge
// insert a WAR dependency edge
insert_state_flow(d, node.get(), output_state);
}
latest_state_readers_[output_state].clear();
Expand Down Expand Up @@ -951,13 +957,9 @@ bool StateFlowGraph::optimize_dead_store() {
}
bool used = false;
for (auto other : task->output_edges[s]) {
if (task->has_state_flow(s, other) &&
(other->meta->input_states.count(s) > 0)) {
if (task->has_state_flow(s, other)) {
// Check if this is a RAW dependency. For scalar SNodes, a WAW flow
// edge degrades to a dependency edge.
//
// TODO: This is a hack that only works for scalar SNodes. The proper
// handling would require value killing analysis.
used = true;
} else {
// Note that a dependency edge does not count as an data usage
Expand Down
13 changes: 4 additions & 9 deletions taichi/program/state_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,15 @@ class StateFlowGraph {

// Note:
// Read-after-write leads to flow edges
// Write-after-write leads to flow edges
// Write-after-write leads to dependency edges
// Write-after-read leads to dependency edges
//
// So an edge is a data flow edge iff the starting node writes to the
// So an edge is a data flow edge iff the destination node reads the
// state.
//

if (is_initial_node) {
// The initial node is special.
return destination->meta->input_states.find(state) !=
destination->meta->input_states.end();
} else {
return meta->output_states.find(state) != meta->output_states.end();
}
return destination->meta->input_states.find(state) !=
destination->meta->input_states.end();
}

void disconnect_all();
Expand Down