Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] [ir] Remove legacy LaneAttribute usage from ExternalPtrStmt/GlobalPtrStmt #5898

Merged
merged 4 commits into from
Aug 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions taichi/analysis/alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
return AliasResult::different;
auto ptr1 = var1->as<ExternalPtrStmt>();
auto ptr2 = var2->as<ExternalPtrStmt>();
if (ptr1->base_ptrs[0] != ptr2->base_ptrs[0]) {
auto base1 = ptr1->base_ptrs[0]->as<ArgLoadStmt>();
auto base2 = ptr2->base_ptrs[0]->as<ArgLoadStmt>();
if (ptr1->base_ptr != ptr2->base_ptr) {
auto base1 = ptr1->base_ptr->as<ArgLoadStmt>();
auto base2 = ptr2->base_ptr->as<ArgLoadStmt>();
if (base1->arg_id != base2->arg_id) {
return AliasResult::different;
}
Expand All @@ -120,7 +120,7 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
// SNode::id.
auto get_snode_id = [](Stmt *s) {
if (auto ptr = s->cast<GlobalPtrStmt>()) {
return ptr->snodes[0]->id;
return ptr->snode->id;
} else if (auto get_child = s->cast<GetChStmt>()) {
return get_child->output_snode->id;
}
Expand All @@ -137,8 +137,8 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
if (var1->is<GlobalPtrStmt>() && var2->is<GlobalPtrStmt>()) {
auto ptr1 = var1->as<GlobalPtrStmt>();
auto ptr2 = var2->as<GlobalPtrStmt>();
auto snode = ptr1->snodes[0];
TI_ASSERT(snode == ptr2->snodes[0]);
auto snode = ptr1->snode;
TI_ASSERT(snode == ptr2->snode);
TI_ASSERT(ptr1->indices.size() == ptr2->indices.size());
bool uncertain = false;
for (int i = 0; i < (int)ptr1->indices.size(); i++) {
Expand Down
2 changes: 1 addition & 1 deletion taichi/analysis/bls_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void BLSAnalyzer::record_access(Stmt *stmt, AccessFlag flag) {
if (!stmt->is<GlobalPtrStmt>())
return; // local alloca
auto ptr = stmt->as<GlobalPtrStmt>();
auto snode = ptr->snodes[0];
auto snode = ptr->snode;
if (!pads_->has(snode)) {
return;
}
Expand Down
10 changes: 4 additions & 6 deletions taichi/analysis/gather_snode_read_writes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,10 @@ gather_snode_read_writes(IRNode *root) {
}
if (ptr) {
if (auto *global_ptr = ptr->cast<GlobalPtrStmt>()) {
for (auto &snode : global_ptr->snodes.data) {
if (read)
accessed.first.emplace(snode);
if (write)
accessed.second.emplace(snode);
}
if (read)
accessed.first.emplace(global_ptr->snode);
if (write)
accessed.second.emplace(global_ptr->snode);
}
}
return false;
Expand Down
121 changes: 57 additions & 64 deletions taichi/analysis/gather_uniquely_accessed_pointers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
}

void visit(GlobalPtrStmt *stmt) override {
auto snode = stmt->snode;
// mesh-for loop unique
if (stmt->indices.size() == 1 &&
stmt->indices[0]->is<MeshIndexConversionStmt>()) {
Expand All @@ -195,36 +196,30 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
}
if (idx->is<LoopIndexStmt>() &&
idx->as<LoopIndexStmt>()->is_mesh_index()) { // from-end access
for (auto &snode : stmt->snodes.data) {
if (rel_access_pointer_.find(snode) ==
rel_access_pointer_.end()) { // not accessed by neighbours yet
accessed_pointer_[snode] = stmt;
} else { // accessed by neighbours, so it's not unique
accessed_pointer_[snode] = nullptr;
}
if (rel_access_pointer_.find(snode) ==
rel_access_pointer_.end()) { // not accessed by neighbours yet
accessed_pointer_[snode] = stmt;
} else { // accessed by neighbours, so it's not unique
accessed_pointer_[snode] = nullptr;
}
} else { // to-end access
for (auto &snode : stmt->snodes.data) {
rel_access_pointer_[snode] = stmt;
accessed_pointer_[snode] =
nullptr; // from-end access should not be unique
}
rel_access_pointer_[snode] = stmt;
accessed_pointer_[snode] =
nullptr; // from-end access should not be unique
}
}
// Range-for / struct-for
for (auto &snode : stmt->snodes.data) {
auto accessed_ptr = accessed_pointer_.find(snode);
if (accessed_ptr == accessed_pointer_.end()) {
if (loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt)) {
accessed_pointer_[snode] = stmt;
} else {
accessed_pointer_[snode] = nullptr; // not loop-unique
}
auto accessed_ptr = accessed_pointer_.find(snode);
if (accessed_ptr == accessed_pointer_.end()) {
if (loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt)) {
accessed_pointer_[snode] = stmt;
} else {
if (!irpass::analysis::definitely_same_address(accessed_ptr->second,
stmt)) {
accessed_ptr->second = nullptr; // not uniquely accessed
}
accessed_pointer_[snode] = nullptr; // not loop-unique
}
} else {
if (!irpass::analysis::definitely_same_address(accessed_ptr->second,
stmt)) {
accessed_ptr->second = nullptr; // not uniquely accessed
}
}
}
Expand All @@ -233,50 +228,48 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
// A memory location of an ExternalPtrStmt depends on the indices
// If the accessed indices are loop unique,
// the accessed memory location is loop unique
for (auto base_ptr : stmt->base_ptrs.data) {
ArgLoadStmt *arg_load_stmt = base_ptr->as<ArgLoadStmt>();
int arg_id = arg_load_stmt->arg_id;
ArgLoadStmt *arg_load_stmt = stmt->base_ptr->as<ArgLoadStmt>();
int arg_id = arg_load_stmt->arg_id;

auto accessed_ptr = accessed_arr_pointer_.find(arg_id);
auto accessed_ptr = accessed_arr_pointer_.find(arg_id);

bool stmt_loop_unique =
loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt);
bool stmt_loop_unique =
loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt);

if (!stmt_loop_unique) {
accessed_arr_pointer_[arg_id] = nullptr; // not loop-unique
if (!stmt_loop_unique) {
accessed_arr_pointer_[arg_id] = nullptr; // not loop-unique
} else {
if (accessed_ptr == accessed_arr_pointer_.end()) {
// First time using arr @ arg_id
accessed_arr_pointer_[arg_id] = stmt;
} else {
if (accessed_ptr == accessed_arr_pointer_.end()) {
// First time using arr @ arg_id
accessed_arr_pointer_[arg_id] = stmt;
} else {
/**
* We know stmt->base_ptr and the previously recorded pointers
* are loop-unique. We need to figure out whether their loop-unique
* indices are the same while ignoring the others.
* e.g. a[i, j, 1] and a[i, j, 2] are both uniquely accessed
* a[i, j, 1] and a[j, i, 2] are not uniquely accessed
* a[i, j + 1, 1] and a[i, j, 2] are not uniquely accessed
* This is a bit stricter than needed.
* e.g. a[i, j, i] and a[i, j, 0] are uniquely accessed
* However this is probably not common and improvements can be made
* in a future patch.
*/
if (accessed_ptr->second) {
ExternalPtrStmt *other_ptr = accessed_ptr->second;
TI_ASSERT(stmt->indices.size() == other_ptr->indices.size());
for (int axis = 0; axis < stmt->indices.size(); axis++) {
Stmt *this_index = stmt->indices[axis];
Stmt *other_index = other_ptr->indices[axis];
// We only compare unique indices here.
// Since both pointers are loop-unique, all the unique indices
// need to be the same for both to be uniquely accessed
if (loop_unique_stmt_searcher_.is_partially_loop_unique(
this_index)) {
if (!irpass::analysis::same_value(this_index, other_index)) {
// Not equal -> not uniquely accessed
accessed_arr_pointer_[arg_id] = nullptr;
break;
}
/**
* We know stmt->base_ptr and the previously recorded pointers
* are loop-unique. We need to figure out whether their loop-unique
* indices are the same while ignoring the others.
* e.g. a[i, j, 1] and a[i, j, 2] are both uniquely accessed
* a[i, j, 1] and a[j, i, 2] are not uniquely accessed
* a[i, j + 1, 1] and a[i, j, 2] are not uniquely accessed
* This is a bit stricter than needed.
* e.g. a[i, j, i] and a[i, j, 0] are uniquely accessed
* However this is probably not common and improvements can be made
* in a future patch.
*/
if (accessed_ptr->second) {
ExternalPtrStmt *other_ptr = accessed_ptr->second;
TI_ASSERT(stmt->indices.size() == other_ptr->indices.size());
for (int axis = 0; axis < stmt->indices.size(); axis++) {
Stmt *this_index = stmt->indices[axis];
Stmt *other_index = other_ptr->indices[axis];
// We only compare unique indices here.
// Since both pointers are loop-unique, all the unique indices
// need to be the same for both to be uniquely accessed
if (loop_unique_stmt_searcher_.is_partially_loop_unique(
this_index)) {
if (!irpass::analysis::same_value(this_index, other_index)) {
// Not equal -> not uniquely accessed
accessed_arr_pointer_[arg_id] = nullptr;
break;
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/analysis/mesh_bls_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void MeshBLSAnalyzer::record_access(Stmt *stmt, AccessFlag flag) {
auto idx = conv->idx;
if (conv_type == mesh::ConvType::g2r)
return;
auto snode = ptr->snodes[0];
auto snode = ptr->snode;
if (!caches_->has(snode)) {
if (auto_mesh_local_ &&
(flag == AccessFlag::accumulate ||
Expand Down
5 changes: 2 additions & 3 deletions taichi/analysis/same_statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,8 @@ class IRNodeComparator : public IRVisitor {
// And we cannot use irpass::analysis::definitely_same_address()
// directly because that function does not support id_map.

// TODO: Update this part if GlobalPtrStmt comes to have more fields
if (stmt->as<GlobalPtrStmt>()->snodes[0]->id !=
other->as<GlobalPtrStmt>()->snodes[0]->id) {
if (stmt->as<GlobalPtrStmt>()->snode->id !=
other->as<GlobalPtrStmt>()->snode->id) {
same = false;
return;
}
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/cc/codegen_cc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class CCTransformer : public IRVisitor {

void visit(ExternalPtrStmt *stmt) override {
std::string offset = "0";
const auto *argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
const auto *argload = stmt->base_ptr->as<ArgLoadStmt>();
const int arg_id = argload->arg_id;
const auto element_shape = stmt->element_shape;
const auto layout = stmt->element_dim < 0 ? ExternalArrayLayout::kAOS
Expand All @@ -177,7 +177,7 @@ class CCTransformer : public IRVisitor {
auto var =
define_var(cc_data_type_name(stmt->element_type().ptr_removed()) + " *",
stmt->raw_name());
emit("{} = {} + {};", var, stmt->base_ptrs[0]->raw_name(), offset);
emit("{} = {} + {};", var, stmt->base_ptr->raw_name(), offset);
}

void visit(ArgLoadStmt *stmt) override {
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1767,7 +1767,7 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) {
}

void TaskCodeGenLLVM::visit(ExternalPtrStmt *stmt) {
auto argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
auto argload = stmt->base_ptr->as<ArgLoadStmt>();
auto arg_id = argload->arg_id;
int num_indices = stmt->indices.size();
std::vector<llvm::Value *> sizes(num_indices);
Expand All @@ -1787,7 +1787,7 @@ void TaskCodeGenLLVM::visit(ExternalPtrStmt *stmt) {

auto dt = stmt->ret_type.ptr_removed();
auto base_ty = tlctx->get_data_type(dt);
auto base = builder->CreateBitCast(llvm_val[stmt->base_ptrs[0]],
auto base = builder->CreateBitCast(llvm_val[stmt->base_ptr],
llvm::PointerType::get(base_ty, 0));

auto linear_index = tlctx->get_constant(0);
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ class KernelCodegenImpl : public IRVisitor {
emit("{{");
{
ScopedIndent s(current_appender());
const auto *argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
const auto *argload = stmt->base_ptr->as<ArgLoadStmt>();
const int arg_id = argload->arg_id;
const int num_indices = stmt->indices.size();
const auto &element_shape = stmt->element_shape;
Expand Down Expand Up @@ -492,7 +492,7 @@ class KernelCodegenImpl : public IRVisitor {

const auto dt = metal_data_type_name(stmt->element_type());
emit("device {} *{} = ({} + {});", dt, stmt->raw_name(),
stmt->base_ptrs[0]->raw_name(), linear_index_name);
stmt->base_ptr->raw_name(), linear_index_name);
}

void visit(GlobalTemporaryStmt *stmt) override {
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ class TaskCodegen : public IRVisitor {
// Used mostly for transferring data between host (e.g. numpy array) and
// device.
spirv::Value linear_offset = ir_->int_immediate_number(ir_->i32_type(), 0);
const auto *argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
const auto *argload = stmt->base_ptr->as<ArgLoadStmt>();
const int arg_id = argload->arg_id;
{
const int num_indices = stmt->indices.size();
Expand Down
26 changes: 6 additions & 20 deletions taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,22 +335,14 @@ void CFGNode::gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const {
if (auto global_ptr = load_ptr->cast<GlobalPtrStmt>()) {
// Avoid computing the UD-chain if the SNode of this global ptr is
// already loaded, because it can be time-consuming.
bool already_loaded = true;
for (auto &snode : global_ptr->snodes.data) {
if (snodes.count(snode) == 0) {
already_loaded = false;
break;
}
}
if (already_loaded) {
auto snode = global_ptr->snode;
if (snodes.count(snode) > 0) {
continue;
}
if (reach_in.find(global_ptr) != reach_in.end() &&
!contain_variable(killed_in_this_node, global_ptr)) {
// The UD-chain contains the value before this offloaded task.
for (auto &snode : global_ptr->snodes.data) {
snodes.insert(snode);
}
snodes.insert(snode);
}
}
}
Expand Down Expand Up @@ -458,9 +450,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
continue;
} else if (!is_parallel_executed ||
(atomic->dest->is<GlobalPtrStmt>() &&
atomic->dest->as<GlobalPtrStmt>()
->snodes[0]
->is_scalar())) {
atomic->dest->as<GlobalPtrStmt>()->snode->is_scalar())) {
// If this node is parallel executed, we can't weaken a global
// atomic operation to a global load.
// TODO: we can weaken it if it's element-wise (i.e. never
Expand Down Expand Up @@ -704,9 +694,7 @@ void ControlFlowGraph::live_variable_analysis(
}
if (auto *gptr = stmt->cast<GlobalPtrStmt>();
gptr && config_opt.has_value()) {
TI_ASSERT(gptr->snodes.size() == 1);
const bool res =
(config_opt->eliminable_snodes.count(gptr->snodes[0]) == 0);
const bool res = (config_opt->eliminable_snodes.count(gptr->snode) == 0);
return res;
}
// A global pointer that may be loaded after this kernel.
Expand Down Expand Up @@ -874,9 +862,7 @@ std::unordered_set<SNode *> ControlFlowGraph::gather_loaded_snodes() {
// Therefore we include the nodes[final_node]->reach_in in snodes.
for (auto &stmt : nodes[final_node]->reach_in) {
if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
for (auto &snode : global_ptr->snodes.data) {
snodes.insert(snode);
}
snodes.insert(global_ptr->snode);
}
}

Expand Down
Loading