Skip to content

Commit

Permalink
[refactor] [ir] Remove legacy LaneAttribute usage from ExternalPtrStmt/GlobalPtrStmt (#5898)
Browse files Browse the repository at this point in the history

* Handle ExternalPtrStmt

* Handle GlobalPtrStmt

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cancel accidental change

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Aug 29, 2022
1 parent dd394a2 commit fb62f1c
Show file tree
Hide file tree
Showing 25 changed files with 213 additions and 316 deletions.
12 changes: 6 additions & 6 deletions taichi/analysis/alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
return AliasResult::different;
auto ptr1 = var1->as<ExternalPtrStmt>();
auto ptr2 = var2->as<ExternalPtrStmt>();
if (ptr1->base_ptrs[0] != ptr2->base_ptrs[0]) {
auto base1 = ptr1->base_ptrs[0]->as<ArgLoadStmt>();
auto base2 = ptr2->base_ptrs[0]->as<ArgLoadStmt>();
if (ptr1->base_ptr != ptr2->base_ptr) {
auto base1 = ptr1->base_ptr->as<ArgLoadStmt>();
auto base2 = ptr2->base_ptr->as<ArgLoadStmt>();
if (base1->arg_id != base2->arg_id) {
return AliasResult::different;
}
Expand All @@ -120,7 +120,7 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
// SNode::id.
auto get_snode_id = [](Stmt *s) {
if (auto ptr = s->cast<GlobalPtrStmt>()) {
return ptr->snodes[0]->id;
return ptr->snode->id;
} else if (auto get_child = s->cast<GetChStmt>()) {
return get_child->output_snode->id;
}
Expand All @@ -137,8 +137,8 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
if (var1->is<GlobalPtrStmt>() && var2->is<GlobalPtrStmt>()) {
auto ptr1 = var1->as<GlobalPtrStmt>();
auto ptr2 = var2->as<GlobalPtrStmt>();
auto snode = ptr1->snodes[0];
TI_ASSERT(snode == ptr2->snodes[0]);
auto snode = ptr1->snode;
TI_ASSERT(snode == ptr2->snode);
TI_ASSERT(ptr1->indices.size() == ptr2->indices.size());
bool uncertain = false;
for (int i = 0; i < (int)ptr1->indices.size(); i++) {
Expand Down
2 changes: 1 addition & 1 deletion taichi/analysis/bls_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void BLSAnalyzer::record_access(Stmt *stmt, AccessFlag flag) {
if (!stmt->is<GlobalPtrStmt>())
return; // local alloca
auto ptr = stmt->as<GlobalPtrStmt>();
auto snode = ptr->snodes[0];
auto snode = ptr->snode;
if (!pads_->has(snode)) {
return;
}
Expand Down
10 changes: 4 additions & 6 deletions taichi/analysis/gather_snode_read_writes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,10 @@ gather_snode_read_writes(IRNode *root) {
}
if (ptr) {
if (auto *global_ptr = ptr->cast<GlobalPtrStmt>()) {
for (auto &snode : global_ptr->snodes.data) {
if (read)
accessed.first.emplace(snode);
if (write)
accessed.second.emplace(snode);
}
if (read)
accessed.first.emplace(global_ptr->snode);
if (write)
accessed.second.emplace(global_ptr->snode);
}
}
return false;
Expand Down
121 changes: 57 additions & 64 deletions taichi/analysis/gather_uniquely_accessed_pointers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
}

void visit(GlobalPtrStmt *stmt) override {
auto snode = stmt->snode;
// mesh-for loop unique
if (stmt->indices.size() == 1 &&
stmt->indices[0]->is<MeshIndexConversionStmt>()) {
Expand All @@ -195,36 +196,30 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
}
if (idx->is<LoopIndexStmt>() &&
idx->as<LoopIndexStmt>()->is_mesh_index()) { // from-end access
for (auto &snode : stmt->snodes.data) {
if (rel_access_pointer_.find(snode) ==
rel_access_pointer_.end()) { // not accessed by neibhours yet
accessed_pointer_[snode] = stmt;
} else { // accessed by neibhours, so it's not unique
accessed_pointer_[snode] = nullptr;
}
if (rel_access_pointer_.find(snode) ==
rel_access_pointer_.end()) { // not accessed by neibhours yet
accessed_pointer_[snode] = stmt;
} else { // accessed by neibhours, so it's not unique
accessed_pointer_[snode] = nullptr;
}
} else { // to-end access
for (auto &snode : stmt->snodes.data) {
rel_access_pointer_[snode] = stmt;
accessed_pointer_[snode] =
nullptr; // from-end access should not be unique
}
rel_access_pointer_[snode] = stmt;
accessed_pointer_[snode] =
nullptr; // from-end access should not be unique
}
}
// Range-for / struct-for
for (auto &snode : stmt->snodes.data) {
auto accessed_ptr = accessed_pointer_.find(snode);
if (accessed_ptr == accessed_pointer_.end()) {
if (loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt)) {
accessed_pointer_[snode] = stmt;
} else {
accessed_pointer_[snode] = nullptr; // not loop-unique
}
auto accessed_ptr = accessed_pointer_.find(snode);
if (accessed_ptr == accessed_pointer_.end()) {
if (loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt)) {
accessed_pointer_[snode] = stmt;
} else {
if (!irpass::analysis::definitely_same_address(accessed_ptr->second,
stmt)) {
accessed_ptr->second = nullptr; // not uniquely accessed
}
accessed_pointer_[snode] = nullptr; // not loop-unique
}
} else {
if (!irpass::analysis::definitely_same_address(accessed_ptr->second,
stmt)) {
accessed_ptr->second = nullptr; // not uniquely accessed
}
}
}
Expand All @@ -233,50 +228,48 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
// A memory location of an ExternalPtrStmt depends on the indices
// If the accessed indices are loop unique,
// the accessed memory location is loop unique
for (auto base_ptr : stmt->base_ptrs.data) {
ArgLoadStmt *arg_load_stmt = base_ptr->as<ArgLoadStmt>();
int arg_id = arg_load_stmt->arg_id;
ArgLoadStmt *arg_load_stmt = stmt->base_ptr->as<ArgLoadStmt>();
int arg_id = arg_load_stmt->arg_id;

auto accessed_ptr = accessed_arr_pointer_.find(arg_id);
auto accessed_ptr = accessed_arr_pointer_.find(arg_id);

bool stmt_loop_unique =
loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt);
bool stmt_loop_unique =
loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt);

if (!stmt_loop_unique) {
accessed_arr_pointer_[arg_id] = nullptr; // not loop-unique
if (!stmt_loop_unique) {
accessed_arr_pointer_[arg_id] = nullptr; // not loop-unique
} else {
if (accessed_ptr == accessed_arr_pointer_.end()) {
// First time using arr @ arg_id
accessed_arr_pointer_[arg_id] = stmt;
} else {
if (accessed_ptr == accessed_arr_pointer_.end()) {
// First time using arr @ arg_id
accessed_arr_pointer_[arg_id] = stmt;
} else {
/**
* We know stmt->base_ptr and the previously recorded pointers
* are loop-unique. We need to figure out whether their loop-unique
* indices are the same while ignoring the others.
* e.g. a[i, j, 1] and a[i, j, 2] are both uniquely accessed
* a[i, j, 1] and a[j, i, 2] are not uniquely accessed
* a[i, j + 1, 1] and a[i, j, 2] are not uniquely accessed
* This is a bit stricter than needed.
* e.g. a[i, j, i] and a[i, j, 0] are uniquely accessed
* However this is probably not common and improvements can be made
* in a future patch.
*/
if (accessed_ptr->second) {
ExternalPtrStmt *other_ptr = accessed_ptr->second;
TI_ASSERT(stmt->indices.size() == other_ptr->indices.size());
for (int axis = 0; axis < stmt->indices.size(); axis++) {
Stmt *this_index = stmt->indices[axis];
Stmt *other_index = other_ptr->indices[axis];
// We only compare unique indices here.
// Since both pointers are loop-unique, all the unique indices
// need to be the same for both to be uniquely accessed
if (loop_unique_stmt_searcher_.is_partially_loop_unique(
this_index)) {
if (!irpass::analysis::same_value(this_index, other_index)) {
// Not equal -> not uniquely accessed
accessed_arr_pointer_[arg_id] = nullptr;
break;
}
/**
* We know stmt->base_ptr and the previously recorded pointers
* are loop-unique. We need to figure out whether their loop-unique
* indices are the same while ignoring the others.
* e.g. a[i, j, 1] and a[i, j, 2] are both uniquely accessed
* a[i, j, 1] and a[j, i, 2] are not uniquely accessed
* a[i, j + 1, 1] and a[i, j, 2] are not uniquely accessed
* This is a bit stricter than needed.
* e.g. a[i, j, i] and a[i, j, 0] are uniquely accessed
* However this is probably not common and improvements can be made
* in a future patch.
*/
if (accessed_ptr->second) {
ExternalPtrStmt *other_ptr = accessed_ptr->second;
TI_ASSERT(stmt->indices.size() == other_ptr->indices.size());
for (int axis = 0; axis < stmt->indices.size(); axis++) {
Stmt *this_index = stmt->indices[axis];
Stmt *other_index = other_ptr->indices[axis];
// We only compare unique indices here.
// Since both pointers are loop-unique, all the unique indices
// need to be the same for both to be uniquely accessed
if (loop_unique_stmt_searcher_.is_partially_loop_unique(
this_index)) {
if (!irpass::analysis::same_value(this_index, other_index)) {
// Not equal -> not uniquely accessed
accessed_arr_pointer_[arg_id] = nullptr;
break;
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/analysis/mesh_bls_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void MeshBLSAnalyzer::record_access(Stmt *stmt, AccessFlag flag) {
auto idx = conv->idx;
if (conv_type == mesh::ConvType::g2r)
return;
auto snode = ptr->snodes[0];
auto snode = ptr->snode;
if (!caches_->has(snode)) {
if (auto_mesh_local_ &&
(flag == AccessFlag::accumulate ||
Expand Down
5 changes: 2 additions & 3 deletions taichi/analysis/same_statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,8 @@ class IRNodeComparator : public IRVisitor {
// And we cannot use irpass::analysis::definitely_same_address()
// directly because that function does not support id_map.

// TODO: Update this part if GlobalPtrStmt comes to have more fields
if (stmt->as<GlobalPtrStmt>()->snodes[0]->id !=
other->as<GlobalPtrStmt>()->snodes[0]->id) {
if (stmt->as<GlobalPtrStmt>()->snode->id !=
other->as<GlobalPtrStmt>()->snode->id) {
same = false;
return;
}
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/cc/codegen_cc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class CCTransformer : public IRVisitor {

void visit(ExternalPtrStmt *stmt) override {
std::string offset = "0";
const auto *argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
const auto *argload = stmt->base_ptr->as<ArgLoadStmt>();
const int arg_id = argload->arg_id;
const auto element_shape = stmt->element_shape;
const auto layout = stmt->element_dim < 0 ? ExternalArrayLayout::kAOS
Expand All @@ -177,7 +177,7 @@ class CCTransformer : public IRVisitor {
auto var =
define_var(cc_data_type_name(stmt->element_type().ptr_removed()) + " *",
stmt->raw_name());
emit("{} = {} + {};", var, stmt->base_ptrs[0]->raw_name(), offset);
emit("{} = {} + {};", var, stmt->base_ptr->raw_name(), offset);
}

void visit(ArgLoadStmt *stmt) override {
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1767,7 +1767,7 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) {
}

void TaskCodeGenLLVM::visit(ExternalPtrStmt *stmt) {
auto argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
auto argload = stmt->base_ptr->as<ArgLoadStmt>();
auto arg_id = argload->arg_id;
int num_indices = stmt->indices.size();
std::vector<llvm::Value *> sizes(num_indices);
Expand All @@ -1787,7 +1787,7 @@ void TaskCodeGenLLVM::visit(ExternalPtrStmt *stmt) {

auto dt = stmt->ret_type.ptr_removed();
auto base_ty = tlctx->get_data_type(dt);
auto base = builder->CreateBitCast(llvm_val[stmt->base_ptrs[0]],
auto base = builder->CreateBitCast(llvm_val[stmt->base_ptr],
llvm::PointerType::get(base_ty, 0));

auto linear_index = tlctx->get_constant(0);
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ class KernelCodegenImpl : public IRVisitor {
emit("{{");
{
ScopedIndent s(current_appender());
const auto *argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
const auto *argload = stmt->base_ptr->as<ArgLoadStmt>();
const int arg_id = argload->arg_id;
const int num_indices = stmt->indices.size();
const auto &element_shape = stmt->element_shape;
Expand Down Expand Up @@ -492,7 +492,7 @@ class KernelCodegenImpl : public IRVisitor {

const auto dt = metal_data_type_name(stmt->element_type());
emit("device {} *{} = ({} + {});", dt, stmt->raw_name(),
stmt->base_ptrs[0]->raw_name(), linear_index_name);
stmt->base_ptr->raw_name(), linear_index_name);
}

void visit(GlobalTemporaryStmt *stmt) override {
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ class TaskCodegen : public IRVisitor {
// Used mostly for transferring data between host (e.g. numpy array) and
// device.
spirv::Value linear_offset = ir_->int_immediate_number(ir_->i32_type(), 0);
const auto *argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
const auto *argload = stmt->base_ptr->as<ArgLoadStmt>();
const int arg_id = argload->arg_id;
{
const int num_indices = stmt->indices.size();
Expand Down
26 changes: 6 additions & 20 deletions taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,22 +335,14 @@ void CFGNode::gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const {
if (auto global_ptr = load_ptr->cast<GlobalPtrStmt>()) {
// Avoid computing the UD-chain if every SNode in this global ptr
// is already loaded because it can be time-consuming.
bool already_loaded = true;
for (auto &snode : global_ptr->snodes.data) {
if (snodes.count(snode) == 0) {
already_loaded = false;
break;
}
}
if (already_loaded) {
auto snode = global_ptr->snode;
if (snodes.count(snode) > 0) {
continue;
}
if (reach_in.find(global_ptr) != reach_in.end() &&
!contain_variable(killed_in_this_node, global_ptr)) {
// The UD-chain contains the value before this offloaded task.
for (auto &snode : global_ptr->snodes.data) {
snodes.insert(snode);
}
snodes.insert(snode);
}
}
}
Expand Down Expand Up @@ -458,9 +450,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
continue;
} else if (!is_parallel_executed ||
(atomic->dest->is<GlobalPtrStmt>() &&
atomic->dest->as<GlobalPtrStmt>()
->snodes[0]
->is_scalar())) {
atomic->dest->as<GlobalPtrStmt>()->snode->is_scalar())) {
// If this node is parallel executed, we can't weaken a global
// atomic operation to a global load.
// TODO: we can weaken it if it's element-wise (i.e. never
Expand Down Expand Up @@ -704,9 +694,7 @@ void ControlFlowGraph::live_variable_analysis(
}
if (auto *gptr = stmt->cast<GlobalPtrStmt>();
gptr && config_opt.has_value()) {
TI_ASSERT(gptr->snodes.size() == 1);
const bool res =
(config_opt->eliminable_snodes.count(gptr->snodes[0]) == 0);
const bool res = (config_opt->eliminable_snodes.count(gptr->snode) == 0);
return res;
}
// A global pointer that may be loaded after this kernel.
Expand Down Expand Up @@ -874,9 +862,7 @@ std::unordered_set<SNode *> ControlFlowGraph::gather_loaded_snodes() {
// Therefore we include the nodes[final_node]->reach_in in snodes.
for (auto &stmt : nodes[final_node]->reach_in) {
if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
for (auto &snode : global_ptr->snodes.data) {
snodes.insert(snode);
}
snodes.insert(global_ptr->snode);
}
}

Expand Down
Loading

0 comments on commit fb62f1c

Please sign in to comment.