diff --git a/taichi/analysis/alias_analysis.cpp b/taichi/analysis/alias_analysis.cpp
index 3e401e3768e67..3fba06827426c 100644
--- a/taichi/analysis/alias_analysis.cpp
+++ b/taichi/analysis/alias_analysis.cpp
@@ -96,9 +96,9 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
       return AliasResult::different;
     auto ptr1 = var1->as<ExternalPtrStmt>();
     auto ptr2 = var2->as<ExternalPtrStmt>();
-    if (ptr1->base_ptrs[0] != ptr2->base_ptrs[0]) {
-      auto base1 = ptr1->base_ptrs[0]->as<ArgLoadStmt>();
-      auto base2 = ptr2->base_ptrs[0]->as<ArgLoadStmt>();
+    if (ptr1->base_ptr != ptr2->base_ptr) {
+      auto base1 = ptr1->base_ptr->as<ArgLoadStmt>();
+      auto base2 = ptr2->base_ptr->as<ArgLoadStmt>();
       if (base1->arg_id != base2->arg_id) {
         return AliasResult::different;
       }
@@ -120,7 +120,7 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
   // SNode::id.
   auto get_snode_id = [](Stmt *s) {
     if (auto ptr = s->cast<GlobalPtrStmt>()) {
-      return ptr->snodes[0]->id;
+      return ptr->snode->id;
     } else if (auto get_child = s->cast<GetChStmt>()) {
       return get_child->output_snode->id;
     }
@@ -137,8 +137,8 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
   if (var1->is<GlobalPtrStmt>() && var2->is<GlobalPtrStmt>()) {
     auto ptr1 = var1->as<GlobalPtrStmt>();
     auto ptr2 = var2->as<GlobalPtrStmt>();
-    auto snode = ptr1->snodes[0];
-    TI_ASSERT(snode == ptr2->snodes[0]);
+    auto snode = ptr1->snode;
+    TI_ASSERT(snode == ptr2->snode);
     TI_ASSERT(ptr1->indices.size() == ptr2->indices.size());
     bool uncertain = false;
     for (int i = 0; i < (int)ptr1->indices.size(); i++) {
diff --git a/taichi/analysis/bls_analyzer.cpp b/taichi/analysis/bls_analyzer.cpp
index 86eb56733615a..2db07a0e7e0a1 100644
--- a/taichi/analysis/bls_analyzer.cpp
+++ b/taichi/analysis/bls_analyzer.cpp
@@ -36,7 +36,7 @@ void BLSAnalyzer::record_access(Stmt *stmt, AccessFlag flag) {
   if (!stmt->is<GlobalPtrStmt>())
     return;  // local alloca
   auto ptr = stmt->as<GlobalPtrStmt>();
-  auto snode = ptr->snodes[0];
+  auto snode = ptr->snode;
   if (!pads_->has(snode)) {
     return;
   }
diff --git a/taichi/analysis/gather_snode_read_writes.cpp b/taichi/analysis/gather_snode_read_writes.cpp
index 45abb217ad84c..270b6a88d074d 100644
--- a/taichi/analysis/gather_snode_read_writes.cpp
+++ b/taichi/analysis/gather_snode_read_writes.cpp
@@ -28,12 +28,10 @@ gather_snode_read_writes(IRNode *root) {
         }
         if (ptr) {
           if (auto *global_ptr = ptr->cast<GlobalPtrStmt>()) {
-            for (auto &snode : global_ptr->snodes.data) {
-              if (read)
-                accessed.first.emplace(snode);
-              if (write)
-                accessed.second.emplace(snode);
-            }
+            if (read)
+              accessed.first.emplace(global_ptr->snode);
+            if (write)
+              accessed.second.emplace(global_ptr->snode);
           }
         }
         return false;
diff --git a/taichi/analysis/gather_uniquely_accessed_pointers.cpp b/taichi/analysis/gather_uniquely_accessed_pointers.cpp
index 1d2a770120d89..e2cfd21af6f30 100644
--- a/taichi/analysis/gather_uniquely_accessed_pointers.cpp
+++ b/taichi/analysis/gather_uniquely_accessed_pointers.cpp
@@ -185,6 +185,7 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
   }

   void visit(GlobalPtrStmt *stmt) override {
+    auto snode = stmt->snode;
     // mesh-for loop unique
     if (stmt->indices.size() == 1 &&
         stmt->indices[0]->is<MeshIndexConversionStmt>()) {
@@ -195,36 +196,30 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
       }
       if (idx->is<LoopIndexStmt>() && idx->as<LoopIndexStmt>()->is_mesh_index()) {
         // from-end access
-        for (auto &snode : stmt->snodes.data) {
-          if (rel_access_pointer_.find(snode) ==
-              rel_access_pointer_.end()) {  // not accessed by neibhours yet
-            accessed_pointer_[snode] = stmt;
-          } else {  // accessed by neibhours, so it's not unique
-            accessed_pointer_[snode] = nullptr;
-          }
+        if (rel_access_pointer_.find(snode) ==
+            rel_access_pointer_.end()) {  // not accessed by neibhours yet
+          accessed_pointer_[snode] = 
stmt; + } else { // accessed by neibhours, so it's not unique + accessed_pointer_[snode] = nullptr; } } else { // to-end access - for (auto &snode : stmt->snodes.data) { - rel_access_pointer_[snode] = stmt; - accessed_pointer_[snode] = - nullptr; // from-end access should not be unique - } + rel_access_pointer_[snode] = stmt; + accessed_pointer_[snode] = + nullptr; // from-end access should not be unique } } // Range-for / struct-for - for (auto &snode : stmt->snodes.data) { - auto accessed_ptr = accessed_pointer_.find(snode); - if (accessed_ptr == accessed_pointer_.end()) { - if (loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt)) { - accessed_pointer_[snode] = stmt; - } else { - accessed_pointer_[snode] = nullptr; // not loop-unique - } + auto accessed_ptr = accessed_pointer_.find(snode); + if (accessed_ptr == accessed_pointer_.end()) { + if (loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt)) { + accessed_pointer_[snode] = stmt; } else { - if (!irpass::analysis::definitely_same_address(accessed_ptr->second, - stmt)) { - accessed_ptr->second = nullptr; // not uniquely accessed - } + accessed_pointer_[snode] = nullptr; // not loop-unique + } + } else { + if (!irpass::analysis::definitely_same_address(accessed_ptr->second, + stmt)) { + accessed_ptr->second = nullptr; // not uniquely accessed } } } @@ -233,50 +228,48 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor { // A memory location of an ExternalPtrStmt depends on the indices // If the accessed indices are loop unique, // the accessed memory location is loop unique - for (auto base_ptr : stmt->base_ptrs.data) { - ArgLoadStmt *arg_load_stmt = base_ptr->as(); - int arg_id = arg_load_stmt->arg_id; + ArgLoadStmt *arg_load_stmt = stmt->base_ptr->as(); + int arg_id = arg_load_stmt->arg_id; - auto accessed_ptr = accessed_arr_pointer_.find(arg_id); + auto accessed_ptr = accessed_arr_pointer_.find(arg_id); - bool stmt_loop_unique = - loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt); + bool stmt_loop_unique = + loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt); - if (!stmt_loop_unique) { - accessed_arr_pointer_[arg_id] = nullptr; // not loop-unique + if (!stmt_loop_unique) { + accessed_arr_pointer_[arg_id] = nullptr; // not loop-unique + } else { + if (accessed_ptr == accessed_arr_pointer_.end()) { + // First time using arr @ arg_id + accessed_arr_pointer_[arg_id] = stmt; } else { - if (accessed_ptr == accessed_arr_pointer_.end()) { - // First time using arr @ arg_id - accessed_arr_pointer_[arg_id] = stmt; - } else { - /** - * We know stmt->base_ptr and the previously recorded pointers - * are loop-unique. We need to figure out whether their loop-unique - * indices are the same while ignoring the others. - * e.g. a[i, j, 1] and a[i, j, 2] are both uniquely accessed - * a[i, j, 1] and a[j, i, 2] are not uniquely accessed - * a[i, j + 1, 1] and a[i, j, 2] are not uniquely accessed - * This is a bit stricter than needed. - * e.g. a[i, j, i] and a[i, j, 0] are uniquely accessed - * However this is probably not common and improvements can be made - * in a future patch. - */ - if (accessed_ptr->second) { - ExternalPtrStmt *other_ptr = accessed_ptr->second; - TI_ASSERT(stmt->indices.size() == other_ptr->indices.size()); - for (int axis = 0; axis < stmt->indices.size(); axis++) { - Stmt *this_index = stmt->indices[axis]; - Stmt *other_index = other_ptr->indices[axis]; - // We only compare unique indices here. 
- // Since both pointers are loop-unique, all the unique indices - // need to be the same for both to be uniquely accessed - if (loop_unique_stmt_searcher_.is_partially_loop_unique( - this_index)) { - if (!irpass::analysis::same_value(this_index, other_index)) { - // Not equal -> not uniquely accessed - accessed_arr_pointer_[arg_id] = nullptr; - break; - } + /** + * We know stmt->base_ptr and the previously recorded pointers + * are loop-unique. We need to figure out whether their loop-unique + * indices are the same while ignoring the others. + * e.g. a[i, j, 1] and a[i, j, 2] are both uniquely accessed + * a[i, j, 1] and a[j, i, 2] are not uniquely accessed + * a[i, j + 1, 1] and a[i, j, 2] are not uniquely accessed + * This is a bit stricter than needed. + * e.g. a[i, j, i] and a[i, j, 0] are uniquely accessed + * However this is probably not common and improvements can be made + * in a future patch. + */ + if (accessed_ptr->second) { + ExternalPtrStmt *other_ptr = accessed_ptr->second; + TI_ASSERT(stmt->indices.size() == other_ptr->indices.size()); + for (int axis = 0; axis < stmt->indices.size(); axis++) { + Stmt *this_index = stmt->indices[axis]; + Stmt *other_index = other_ptr->indices[axis]; + // We only compare unique indices here. + // Since both pointers are loop-unique, all the unique indices + // need to be the same for both to be uniquely accessed + if (loop_unique_stmt_searcher_.is_partially_loop_unique( + this_index)) { + if (!irpass::analysis::same_value(this_index, other_index)) { + // Not equal -> not uniquely accessed + accessed_arr_pointer_[arg_id] = nullptr; + break; } } } diff --git a/taichi/analysis/mesh_bls_analyzer.cpp b/taichi/analysis/mesh_bls_analyzer.cpp index c64bedf310148..91ebc35b48da1 100644 --- a/taichi/analysis/mesh_bls_analyzer.cpp +++ b/taichi/analysis/mesh_bls_analyzer.cpp @@ -35,7 +35,7 @@ void MeshBLSAnalyzer::record_access(Stmt *stmt, AccessFlag flag) { auto idx = conv->idx; if (conv_type == mesh::ConvType::g2r) return; - auto snode = ptr->snodes[0]; + auto snode = ptr->snode; if (!caches_->has(snode)) { if (auto_mesh_local_ && (flag == AccessFlag::accumulate || diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index c1e132cb4d15d..9022e4cb9fd46 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -136,9 +136,8 @@ class IRNodeComparator : public IRVisitor { // And we cannot use irpass::analysis::definitely_same_address() // directly because that function does not support id_map. - // TODO: Update this part if GlobalPtrStmt comes to have more fields - if (stmt->as()->snodes[0]->id != - other->as()->snodes[0]->id) { + if (stmt->as()->snode->id != + other->as()->snode->id) { same = false; return; } diff --git a/taichi/codegen/cc/codegen_cc.cpp b/taichi/codegen/cc/codegen_cc.cpp index 9a18de7742ba7..2bff22b9e6b1d 100644 --- a/taichi/codegen/cc/codegen_cc.cpp +++ b/taichi/codegen/cc/codegen_cc.cpp @@ -152,7 +152,7 @@ class CCTransformer : public IRVisitor { void visit(ExternalPtrStmt *stmt) override { std::string offset = "0"; - const auto *argload = stmt->base_ptrs[0]->as(); + const auto *argload = stmt->base_ptr->as(); const int arg_id = argload->arg_id; const auto element_shape = stmt->element_shape; const auto layout = stmt->element_dim < 0 ? 
ExternalArrayLayout::kAOS
@@ -177,7 +177,7 @@ class CCTransformer : public IRVisitor {
     auto var = define_var(cc_data_type_name(stmt->element_type().ptr_removed()) + " *",
                           stmt->raw_name());
-    emit("{} = {} + {};", var, stmt->base_ptrs[0]->raw_name(), offset);
+    emit("{} = {} + {};", var, stmt->base_ptr->raw_name(), offset);
   }

   void visit(ArgLoadStmt *stmt) override {
diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp
index 5101b82c65c5a..0eea94a615bea 100644
--- a/taichi/codegen/llvm/codegen_llvm.cpp
+++ b/taichi/codegen/llvm/codegen_llvm.cpp
@@ -1767,7 +1767,7 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) {
 }

 void TaskCodeGenLLVM::visit(ExternalPtrStmt *stmt) {
-  auto argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
+  auto argload = stmt->base_ptr->as<ArgLoadStmt>();
   auto arg_id = argload->arg_id;
   int num_indices = stmt->indices.size();
   std::vector sizes(num_indices);
@@ -1787,7 +1787,7 @@ void TaskCodeGenLLVM::visit(ExternalPtrStmt *stmt) {
   auto dt = stmt->ret_type.ptr_removed();
   auto base_ty = tlctx->get_data_type(dt);
-  auto base = builder->CreateBitCast(llvm_val[stmt->base_ptrs[0]],
+  auto base = builder->CreateBitCast(llvm_val[stmt->base_ptr],
                                      llvm::PointerType::get(base_ty, 0));
   auto linear_index = tlctx->get_constant(0);
diff --git a/taichi/codegen/metal/codegen_metal.cpp b/taichi/codegen/metal/codegen_metal.cpp
index f420fe3ae3221..64ffd7daa4a7e 100644
--- a/taichi/codegen/metal/codegen_metal.cpp
+++ b/taichi/codegen/metal/codegen_metal.cpp
@@ -458,7 +458,7 @@ class KernelCodegenImpl : public IRVisitor {
     emit("{{");
     {
       ScopedIndent s(current_appender());
-      const auto *argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
+      const auto *argload = stmt->base_ptr->as<ArgLoadStmt>();
       const int arg_id = argload->arg_id;
       const int num_indices = stmt->indices.size();
       const auto &element_shape = stmt->element_shape;
@@ -492,7 +492,7 @@ class KernelCodegenImpl : public IRVisitor {
     const auto dt = metal_data_type_name(stmt->element_type());
     emit("device {} *{} = ({} + {});", dt, stmt->raw_name(),
-         stmt->base_ptrs[0]->raw_name(), linear_index_name);
+         stmt->base_ptr->raw_name(), linear_index_name);
   }

   void visit(GlobalTemporaryStmt *stmt) override {
diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp
index ae794ff99a6d8..d99bc53810b13 100644
--- a/taichi/codegen/spirv/spirv_codegen.cpp
+++ b/taichi/codegen/spirv/spirv_codegen.cpp
@@ -588,7 +588,7 @@ class TaskCodegen : public IRVisitor {
     // Used mostly for transferring data between host (e.g. numpy array) and
     // device.
     spirv::Value linear_offset = ir_->int_immediate_number(ir_->i32_type(), 0);
-    const auto *argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
+    const auto *argload = stmt->base_ptr->as<ArgLoadStmt>();
     const int arg_id = argload->arg_id;
     {
       const int num_indices = stmt->indices.size();
diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp
index 282c22f7092ea..1a6949fa5f44d 100644
--- a/taichi/ir/control_flow_graph.cpp
+++ b/taichi/ir/control_flow_graph.cpp
@@ -335,22 +335,14 @@ void CFGNode::gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const {
       if (auto global_ptr = load_ptr->cast<GlobalPtrStmt>()) {
         // Avoid computing the UD-chain if every SNode in this global ptr
         // are already loaded because it can be time-consuming. 
- bool already_loaded = true; - for (auto &snode : global_ptr->snodes.data) { - if (snodes.count(snode) == 0) { - already_loaded = false; - break; - } - } - if (already_loaded) { + auto snode = global_ptr->snode; + if (snodes.count(snode) > 0) { continue; } if (reach_in.find(global_ptr) != reach_in.end() && !contain_variable(killed_in_this_node, global_ptr)) { // The UD-chain contains the value before this offloaded task. - for (auto &snode : global_ptr->snodes.data) { - snodes.insert(snode); - } + snodes.insert(snode); } } } @@ -458,9 +450,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { continue; } else if (!is_parallel_executed || (atomic->dest->is() && - atomic->dest->as() - ->snodes[0] - ->is_scalar())) { + atomic->dest->as()->snode->is_scalar())) { // If this node is parallel executed, we can't weaken a global // atomic operation to a global load. // TODO: we can weaken it if it's element-wise (i.e. never @@ -704,9 +694,7 @@ void ControlFlowGraph::live_variable_analysis( } if (auto *gptr = stmt->cast(); gptr && config_opt.has_value()) { - TI_ASSERT(gptr->snodes.size() == 1); - const bool res = - (config_opt->eliminable_snodes.count(gptr->snodes[0]) == 0); + const bool res = (config_opt->eliminable_snodes.count(gptr->snode) == 0); return res; } // A global pointer that may be loaded after this kernel. @@ -874,9 +862,7 @@ std::unordered_set ControlFlowGraph::gather_loaded_snodes() { // Therefore we include the nodes[final_node]->reach_in in snodes. for (auto &stmt : nodes[final_node]->reach_in) { if (auto global_ptr = stmt->cast()) { - for (auto &snode : global_ptr->snodes.data) { - snodes.insert(snode); - } + snodes.insert(global_ptr->snode); } } diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 80a36ea29c6bf..4909b40f4bf9e 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -32,78 +32,35 @@ bool UnaryOpStmt::same_operation(UnaryOpStmt *o) const { return false; } -ExternalPtrStmt::ExternalPtrStmt(const LaneAttribute &base_ptrs, +ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr, const std::vector &indices) - : base_ptrs(base_ptrs), indices(indices) { - DataType dt = PrimitiveType::f32; - for (int i = 0; i < (int)base_ptrs.size(); i++) { - TI_ASSERT(base_ptrs[i] != nullptr); - TI_ASSERT(base_ptrs[i]->is()); - } - TI_ASSERT(base_ptrs.size() == 1); - element_type() = dt; + : base_ptr(base_ptr), indices(indices) { + TI_ASSERT(base_ptr != nullptr); + TI_ASSERT(base_ptr->is()); TI_STMT_REG_FIELDS; } -ExternalPtrStmt::ExternalPtrStmt(const LaneAttribute &base_ptrs, +ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr, const std::vector &indices, const std::vector &element_shape, int element_dim) - : ExternalPtrStmt(base_ptrs, indices) { + : ExternalPtrStmt(base_ptr, indices) { this->element_shape = element_shape; this->element_dim = element_dim; } -GlobalPtrStmt::GlobalPtrStmt(const LaneAttribute &snodes, +GlobalPtrStmt::GlobalPtrStmt(SNode *snode, const std::vector &indices, bool activate) - : snodes(snodes), + : snode(snode), indices(indices), activate(activate), is_bit_vectorized(false) { - for (int i = 0; i < (int)snodes.size(); i++) { - TI_ASSERT(snodes[i] != nullptr); - TI_ASSERT(snodes[0]->dt == snodes[i]->dt); - } - TI_ASSERT(snodes.size() == 1); - element_type() = snodes[0]->dt; + TI_ASSERT(snode != nullptr); + element_type() = snode->dt; TI_STMT_REG_FIELDS; } -bool GlobalPtrStmt::is_element_wise(const SNode *snode) const { - if (snode == nullptr) { - // check every SNode when "snode" is nullptr - for (const auto &snode_i : 
snodes.data) {
-      if (!is_element_wise(snode_i)) {
-        return false;
-      }
-    }
-    return true;
-  }
-  // check if this statement is element-wise on a specific SNode, i.e., argument
-  // "snode"
-  for (int i = 0; i < (int)indices.size(); i++) {
-    if (auto loop_index_i = indices[i]->cast<LoopIndexStmt>();
-        !(loop_index_i && loop_index_i->loop->is<OffloadedStmt>() &&
-          loop_index_i->index == snode->physical_index_position[i])) {
-      return false;
-    }
-  }
-  return true;
-}
-
-bool GlobalPtrStmt::covers_snode(const SNode *snode) const {
-  // Check if the addresses of this statement all over the loop cover
-  // all active indices of the snode.
-  for (auto &index : indices) {
-    if (auto loop_unique = index->cast<LoopUniqueStmt>()) {
-      if (loop_unique->covers_snode(snode))
-        return true;
-    }
-  }
-  return is_element_wise(snode);
-}
-
 PtrOffsetStmt::PtrOffsetStmt(Stmt *origin_input, Stmt *offset_input) {
   origin = origin_input;
   offset = offset_input;
@@ -168,14 +125,6 @@ LoopUniqueStmt::LoopUniqueStmt(Stmt *input, const std::vector<SNode *> &covers)
   TI_STMT_REG_FIELDS;
 }

-bool LoopUniqueStmt::covers_snode(const SNode *snode) const {
-  if (snode->is_place()) {
-    return covers.count(snode->parent->id) > 0;
-  } else {
-    TI_NOT_IMPLEMENTED
-  }
-}
-
 Stmt *LocalLoadStmt::previous_store_or_alloca_in_block() {
   int position = parent->locate(this);
   // TI_ASSERT(width() == 1);
diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h
index 9dd71663cc0a5..1bb042cf78846 100644
--- a/taichi/ir/statements.h
+++ b/taichi/ir/statements.h
@@ -292,22 +292,21 @@ class AtomicOpStmt : public Stmt {
 };

 /**
- * An external pointer. |base_ptrs| should be ArgLoadStmts with
+ * An external pointer. |base_ptr| should be ArgLoadStmt with
  * |is_ptr| == true.
  */
 class ExternalPtrStmt : public Stmt {
  public:
-  LaneAttribute<Stmt *> base_ptrs;
+  Stmt *base_ptr;
   std::vector<Stmt *> indices;
   std::vector<int> element_shape;
   // AOS: element_dim < 0
   // SOA: element_dim > 0
   int element_dim;

-  ExternalPtrStmt(const LaneAttribute<Stmt *> &base_ptrs,
-                  const std::vector<Stmt *> &indices);
+  ExternalPtrStmt(Stmt *base_ptr, const std::vector<Stmt *> &indices);

-  ExternalPtrStmt(const LaneAttribute<Stmt *> &base_ptrs,
+  ExternalPtrStmt(Stmt *base_ptr,
                   const std::vector<Stmt *> &indices,
                   const std::vector<int> &element_shape,
                   int element_dim);
@@ -316,7 +315,7 @@ class ExternalPtrStmt : public Stmt {
     return false;
   }

-  TI_STMT_DEF_FIELDS(ret_type, base_ptrs, indices);
+  TI_STMT_DEF_FIELDS(ret_type, base_ptr, indices);
   TI_DEFINE_ACCEPT_AND_CLONE
 };

@@ -330,19 +329,15 @@ class ExternalPtrStmt : public Stmt {
  */
 class GlobalPtrStmt : public Stmt {
  public:
-  LaneAttribute<SNode *> snodes;
+  SNode *snode;
   std::vector<Stmt *> indices;
   bool activate;
   bool is_bit_vectorized;  // for bit_loop_vectorize pass

-  GlobalPtrStmt(const LaneAttribute<SNode *> &snodes,
+  GlobalPtrStmt(SNode *snode,
                 const std::vector<Stmt *> &indices,
                 bool activate = true);

-  bool is_element_wise(const SNode *snode) const;
-
-  bool covers_snode(const SNode *snode) const;
-
   bool has_global_side_effect() const override {
     return activate;
   }
@@ -351,7 +346,7 @@ class GlobalPtrStmt : public Stmt {
     return true;
   }

-  TI_STMT_DEF_FIELDS(ret_type, snodes, indices, activate, is_bit_vectorized);
+  TI_STMT_DEF_FIELDS(ret_type, snode, indices, activate, is_bit_vectorized);
   TI_DEFINE_ACCEPT_AND_CLONE
 };

@@ -540,8 +535,6 @@ class LoopUniqueStmt : public Stmt {

   LoopUniqueStmt(Stmt *input, const std::vector<SNode *> &covers);

-  bool covers_snode(const SNode *snode) const;
-
   bool has_global_side_effect() const override {
     return false;
   }
diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp
index 7d12db7d83c5f..d58cd7aed24b5 100644
--- 
a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -35,11 +35,8 @@ class IndependentBlocksJudger : public BasicStmtVisitor { if (is_inside_loop_) return; TI_ASSERT(stmt->dest->is()); - for (const auto &node : stmt->dest->cast()->snodes.data) { - if (node->has_adjoint()) { - qualified_atomics_ = false; - break; - } + if (stmt->dest->as()->snode->has_adjoint()) { + qualified_atomics_ = false; } } @@ -1001,18 +998,18 @@ class MakeAdjoint : public ADTransform { src = stmt->src->as(); } - auto snodes = src->snodes; - if (!snodes[0]->has_adjoint()) { + auto snode = src->snode; + if (!snode->has_adjoint()) { // No adjoint SNode. Do nothing return; } - if (gradients_stopped(stmt, snodes[0])) { + if (gradients_stopped(stmt, snode)) { // gradients stopped, do nothing. return; } - TI_ASSERT(snodes[0]->get_adjoint() != nullptr); - snodes[0] = snodes[0]->get_adjoint(); - auto adj_ptr = insert(snodes, src->indices); + TI_ASSERT(snode->get_adjoint() != nullptr); + snode = snode->get_adjoint(); + auto adj_ptr = insert(snode, src->indices); if (is_ptr_offset) { adj_ptr = insert(adj_ptr, stmt->src->as()->offset); @@ -1037,14 +1034,14 @@ class MakeAdjoint : public ADTransform { dest = stmt->dest->as(); } - auto snodes = dest->snodes; - if (!snodes[0]->has_adjoint()) { + auto snode = dest->snode; + if (!snode->has_adjoint()) { // no gradient (likely integer types) return; } - TI_ASSERT(snodes[0]->get_adjoint() != nullptr); - snodes[0] = snodes[0]->get_adjoint(); - auto adjoint_ptr = insert(snodes, dest->indices); + TI_ASSERT(snode->get_adjoint() != nullptr); + snode = snode->get_adjoint(); + auto adjoint_ptr = insert(snode, dest->indices); if (is_ptr_offset) { adjoint_ptr = insert( adjoint_ptr, stmt->dest->as()->offset); @@ -1064,15 +1061,15 @@ class MakeAdjoint : public ADTransform { dest = stmt->dest->as(); } - auto snodes = dest->snodes; - if (!snodes[0]->has_adjoint()) { + auto snode = dest->snode; + if (!snode->has_adjoint()) { // no gradient (likely integer types) return; } - TI_ASSERT(snodes[0]->get_adjoint() != nullptr); - snodes[0] = snodes[0]->get_adjoint(); - auto adjoint_ptr = insert(snodes, dest->indices); + TI_ASSERT(snode->get_adjoint() != nullptr); + snode = snode->get_adjoint(); + auto adjoint_ptr = insert(snode, dest->indices); if (is_ptr_offset) { adjoint_ptr = insert( adjoint_ptr, stmt->dest->as()->offset); @@ -1327,18 +1324,18 @@ class MakeDual : public ADTransform { } else { src = stmt->src->as(); } - auto snodes = src->snodes; - if (!snodes[0]->has_dual()) { + auto snode = src->snode; + if (!snode->has_dual()) { // No dual SNode. Do nothing return; } - if (gradients_stopped(stmt, snodes[0])) { + if (gradients_stopped(stmt, snode)) { // gradients stopped, do nothing. 
return; } - TI_ASSERT(snodes[0]->get_dual() != nullptr); - snodes[0] = snodes[0]->get_dual(); - auto dual_ptr = insert(snodes, src->indices); + TI_ASSERT(snode->get_dual() != nullptr); + snode = snode->get_dual(); + auto dual_ptr = insert(snode, src->indices); if (is_ptr_offset) { dual_ptr = insert(dual_ptr, stmt->src->as()->offset); @@ -1355,14 +1352,14 @@ class MakeDual : public ADTransform { } else { dest = stmt->dest->as(); } - auto snodes = dest->snodes; - if (!snodes[0]->has_dual()) { + auto snode = dest->snode; + if (!snode->has_dual()) { // no gradient (likely integer types) return; } - TI_ASSERT(snodes[0]->get_dual() != nullptr); - snodes[0] = snodes[0]->get_dual(); - auto dual_ptr = insert(snodes, dest->indices); + TI_ASSERT(snode->get_dual() != nullptr); + snode = snode->get_dual(); + auto dual_ptr = insert(snode, dest->indices); if (is_ptr_offset) { dual_ptr = insert(dual_ptr, stmt->dest->as()->offset); @@ -1379,14 +1376,14 @@ class MakeDual : public ADTransform { } else { dest = stmt->dest->as(); } - auto snodes = dest->snodes; - if (!snodes[0]->has_dual()) { + auto snode = dest->snode; + if (!snode->has_dual()) { // no gradient (likely integer types) return; } - TI_ASSERT(snodes[0]->get_dual() != nullptr); - snodes[0] = snodes[0]->get_dual(); - auto dual_ptr = insert(snodes, dest->indices); + TI_ASSERT(snode->get_dual() != nullptr); + snode = snode->get_dual(); + auto dual_ptr = insert(snode, dest->indices); if (is_ptr_offset) { dual_ptr = insert(dual_ptr, stmt->dest->as()->offset); @@ -1550,28 +1547,28 @@ class GloablDataAccessRuleChecker : public BasicStmtVisitor { void visit(GlobalLoadStmt *stmt) override { GlobalPtrStmt *src = stmt->src->as(); - auto snodes = src->snodes; - if (!snodes[0]->has_adjoint_checkbit()) { + auto snode = src->snode; + if (!snode->has_adjoint_checkbit()) { return; } - TI_ASSERT(snodes[0]->get_adjoint_checkbit() != nullptr); - snodes[0] = snodes[0]->get_adjoint_checkbit(); + TI_ASSERT(snode->get_adjoint_checkbit() != nullptr); + snode = snode->get_adjoint_checkbit(); auto gloabl_ptr = - stmt->insert_after_me(Stmt::make(snodes, src->indices)); + stmt->insert_after_me(Stmt::make(snode, src->indices)); auto one = gloabl_ptr->insert_after_me( Stmt::make(LaneAttribute(1))); one->insert_after_me(Stmt::make(gloabl_ptr, one)); } void visit_gloabl_store_stmt_and_atomic_add(Stmt *stmt, GlobalPtrStmt *dest) { - auto snodes = dest->snodes; - if (!snodes[0]->has_adjoint_checkbit()) { + auto snode = dest->snode; + if (!snode->has_adjoint_checkbit()) { return; } - TI_ASSERT(snodes[0]->get_adjoint_checkbit() != nullptr); - snodes[0] = snodes[0]->get_adjoint_checkbit(); - auto global_ptr = stmt->insert_before_me( - Stmt::make(snodes, dest->indices)); + TI_ASSERT(snode->get_adjoint_checkbit() != nullptr); + snode = snode->get_adjoint_checkbit(); + auto global_ptr = + stmt->insert_before_me(Stmt::make(snode, dest->indices)); auto global_load = stmt->insert_before_me(Stmt::make(global_ptr)); auto zero = stmt->insert_before_me( @@ -1581,7 +1578,7 @@ class GloablDataAccessRuleChecker : public BasicStmtVisitor { std::string msg = fmt::format( "(kernel={}) Breaks the global data access rule. 
Snode {} is " "overwritten unexpectedly.", - kernel_name_, dest->snodes[0]->get_node_type_name()); + kernel_name_, dest->snode->get_node_type_name()); msg += "\n" + stmt->tb; stmt->insert_before_me( diff --git a/taichi/transforms/bit_loop_vectorize.cpp b/taichi/transforms/bit_loop_vectorize.cpp index 51a88fe4fdce7..d2448919d6d4a 100644 --- a/taichi/transforms/bit_loop_vectorize.cpp +++ b/taichi/transforms/bit_loop_vectorize.cpp @@ -61,7 +61,7 @@ class BitLoopVectorize : public IRVisitor { auto indices = ptr->indices; indices[1] = loop_stmt->body->statements[1].get(); auto base_ptr = - std::make_unique(ptr->snodes, indices); + std::make_unique(ptr->snode, indices); base_ptr->ret_type = new_ret_type; base_ptr->is_bit_vectorized = true; // load x[i, j](base) @@ -80,7 +80,7 @@ class BitLoopVectorize : public IRVisitor { offset_index_opcode, indices[1], offset_constant.get()); indices[1] = offset_index.get(); auto offset_ptr = - std::make_unique(ptr->snodes, indices); + std::make_unique(ptr->snode, indices); offset_ptr->ret_type = new_ret_type; offset_ptr->is_bit_vectorized = true; auto load_offsetted = diff --git a/taichi/transforms/check_out_of_bound.cpp b/taichi/transforms/check_out_of_bound.cpp index 47e2105e954bb..32734bf55cb10 100644 --- a/taichi/transforms/check_out_of_bound.cpp +++ b/taichi/transforms/check_out_of_bound.cpp @@ -43,8 +43,7 @@ class CheckOutOfBound : public BasicStmtVisitor { void visit(GlobalPtrStmt *stmt) override { if (is_done(stmt)) return; - TI_ASSERT(stmt->snodes.size() == 1); - auto snode = stmt->snodes[0]; + auto snode = stmt->snode; bool has_offset = !(snode->index_offsets.empty()); auto new_stmts = VecStatement(); auto zero = new_stmts.push_back(LaneAttribute(0)); diff --git a/taichi/transforms/demote_atomics.cpp b/taichi/transforms/demote_atomics.cpp index 93118d9ed9de5..14b143e7b27d3 100644 --- a/taichi/transforms/demote_atomics.cpp +++ b/taichi/transforms/demote_atomics.cpp @@ -46,21 +46,18 @@ class DemoteAtomics : public BasicStmtVisitor { if (stmt->dest->is()) { demote = true; auto dest = stmt->dest->as(); - for (auto snode : dest->snodes.data) { - if (loop_unique_ptr_[snode] == nullptr || - loop_unique_ptr_[snode]->indices.empty()) { - // not uniquely accessed - demote = false; - break; - } - if (current_offloaded->mem_access_opt.has_flag( - snode, SNodeAccessFlag::block_local) || - current_offloaded->mem_access_opt.has_flag( - snode, SNodeAccessFlag::mesh_local)) { - // BLS does not support write access yet so we keep atomic_adds. - demote = false; - break; - } + auto snode = dest->snode; + if (loop_unique_ptr_[snode] == nullptr || + loop_unique_ptr_[snode]->indices.empty()) { + // not uniquely accessed + demote = false; + } + if (current_offloaded->mem_access_opt.has_flag( + snode, SNodeAccessFlag::block_local) || + current_offloaded->mem_access_opt.has_flag( + snode, SNodeAccessFlag::mesh_local)) { + // BLS does not support write access yet so we keep atomic_adds. 
+ demote = false; } // demote from-end atomics if (current_offloaded->task_type == OffloadedTaskType::mesh_for) { @@ -73,8 +70,8 @@ class DemoteAtomics : public BasicStmtVisitor { } if (idx->is() && idx->as()->is_mesh_index() && - loop_unique_ptr_[stmt->dest->as() - ->snodes.data[0]] != nullptr) { + loop_unique_ptr_[stmt->dest->as()->snode] != + nullptr) { demote = true; } } @@ -85,13 +82,11 @@ class DemoteAtomics : public BasicStmtVisitor { if (dest_ptr->indices.empty()) { demote = false; } - for (Stmt *base_stmt : dest_ptr->base_ptrs.data) { - ArgLoadStmt *arg_load_stmt = base_stmt->as(); - int arg_id = arg_load_stmt->arg_id; - if (loop_unique_arr_ptr_[arg_id] == nullptr) { - // Not loop unique - demote = false; - } + ArgLoadStmt *arg_load_stmt = dest_ptr->base_ptr->as(); + int arg_id = arg_load_stmt->arg_id; + if (loop_unique_arr_ptr_[arg_id] == nullptr) { + // Not loop unique + demote = false; } // TODO: Is BLS / Mem Access Opt a thing for any_arr? } diff --git a/taichi/transforms/detect_read_only.cpp b/taichi/transforms/detect_read_only.cpp index fb5d5573a2795..c0e2825d4b36a 100644 --- a/taichi/transforms/detect_read_only.cpp +++ b/taichi/transforms/detect_read_only.cpp @@ -38,7 +38,7 @@ class ExternalPtrAccessVisitor : public BasicStmtVisitor { return; ExternalPtrStmt *src = stmt->src->cast(); - ArgLoadStmt *arg = src->base_ptrs.data[0]->cast(); + ArgLoadStmt *arg = src->base_ptr->cast(); if (map_.find(arg->arg_id) != map_.end()) { map_[arg->arg_id] = map_[arg->arg_id] | ExternalPtrAccess::READ; } else { @@ -51,7 +51,7 @@ class ExternalPtrAccessVisitor : public BasicStmtVisitor { return; ExternalPtrStmt *dst = stmt->dest->cast(); - ArgLoadStmt *arg = dst->base_ptrs.data[0]->cast(); + ArgLoadStmt *arg = dst->base_ptr->cast(); if (map_.find(arg->arg_id) != map_.end()) { map_[arg->arg_id] = map_[arg->arg_id] | ExternalPtrAccess::WRITE; } else { @@ -65,7 +65,7 @@ class ExternalPtrAccessVisitor : public BasicStmtVisitor { // Atomics modifies existing state (therefore both read & write) ExternalPtrStmt *dst = stmt->dest->cast(); - ArgLoadStmt *arg = dst->base_ptrs.data[0]->cast(); + ArgLoadStmt *arg = dst->base_ptr->cast(); map_[arg->arg_id] = ExternalPtrAccess::WRITE | ExternalPtrAccess::READ; } }; diff --git a/taichi/transforms/flag_access.cpp b/taichi/transforms/flag_access.cpp index f27e4de7cf961..bf754b518dc1e 100644 --- a/taichi/transforms/flag_access.cpp +++ b/taichi/transforms/flag_access.cpp @@ -133,30 +133,28 @@ class WeakenAccess : public BasicStmtVisitor { current_struct_for_; if (is_struct_for) { bool same_as_loop_snode = true; - for (auto snode : stmt->snodes.data) { - SNode *loop_snode = nullptr; - if (current_struct_for_) { - loop_snode = current_struct_for_->snode; - } else { - loop_snode = current_offload_->snode; - } - TI_ASSERT(loop_snode); - if (!share_sparsity(snode, loop_snode)) { - same_as_loop_snode = false; - } - if (stmt->indices.size() == loop_snode->num_active_indices) - for (int i = 0; i < loop_snode->num_active_indices; i++) { - auto ind = stmt->indices[i]; - // TODO: vectorized cases? 
- if (auto loop_var = ind->cast()) { - if (loop_var->index != i) { - same_as_loop_snode = false; - } - } else { + SNode *loop_snode = nullptr; + if (current_struct_for_) { + loop_snode = current_struct_for_->snode; + } else { + loop_snode = current_offload_->snode; + } + TI_ASSERT(loop_snode); + if (!share_sparsity(stmt->snode, loop_snode)) { + same_as_loop_snode = false; + } + if (stmt->indices.size() == loop_snode->num_active_indices) + for (int i = 0; i < loop_snode->num_active_indices; i++) { + auto ind = stmt->indices[i]; + // TODO: vectorized cases? + if (auto loop_var = ind->cast()) { + if (loop_var->index != i) { same_as_loop_snode = false; } + } else { + same_as_loop_snode = false; } - } + } if (same_as_loop_snode) stmt->activate = false; } diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 23b5064832454..ffa5decd26e34 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -392,8 +392,8 @@ class IRPrinter : public IRVisitor { fmt::format("{}{} = global ptr [", stmt->type_hint(), stmt->name()); std::string snode_name; - if (stmt->snodes[0]) { - snode_name = stmt->snodes[0]->get_node_type_name_hinted(); + if (stmt->snode) { + snode_name = stmt->snode->get_node_type_name_hinted(); } else { snode_name = "unknown"; } @@ -528,14 +528,8 @@ class IRPrinter : public IRVisitor { } void visit(ExternalPtrStmt *stmt) override { - std::string s = "<"; - for (int i = 0; i < (int)stmt->base_ptrs.size(); i++) { - s += fmt::format("{}", stmt->base_ptrs[i]->name()); - if (i + 1 < (int)stmt->base_ptrs.size()) { - s += ", "; - } - } - s += ">, ["; + std::string s = stmt->base_ptr->name(); + s += ", ["; for (int i = 0; i < (int)stmt->indices.size(); i++) { s += fmt::format("{}", stmt->indices[i]->name()); if (i + 1 < (int)stmt->indices.size()) { diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index c284cf440768b..1597fb68157b5 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -99,8 +99,9 @@ class LowerAccess : public IRVisitor { // For ti.is_active TI_ASSERT(!activate); } - PtrLowererImpl lowerer{ptr->snodes[0], ptr->indices, snode_op, - ptr->is_bit_vectorized, &lowered, packed}; + PtrLowererImpl lowerer{ptr->snode, ptr->indices, + snode_op, ptr->is_bit_vectorized, + &lowered, packed}; lowerer.set_pointer_needs_activation(activate); lowerer.set_lower_access(this); lowerer.run(); @@ -109,12 +110,12 @@ class LowerAccess : public IRVisitor { if (ptr->is_bit_vectorized) { // if the global ptr is bit vectorized, we start from the place snode // and find the parent quant array snode, use its physical type - auto parent_ret_type = ptr->snodes[0]->parent->physical_type; + auto parent_ret_type = ptr->snode->parent->physical_type; auto ptr_ret_type = TypeFactory::get_instance().get_pointer_type(parent_ret_type); lowered_ptr->ret_type = DataType(ptr_ret_type); } else { - lowered_ptr->ret_type = ptr->snodes[0]->dt; + lowered_ptr->ret_type = ptr->snode->dt; } return lowered; } diff --git a/taichi/transforms/make_block_local.cpp b/taichi/transforms/make_block_local.cpp index 18aa2b2e7c7eb..eb00a89afd712 100644 --- a/taichi/transforms/make_block_local.cpp +++ b/taichi/transforms/make_block_local.cpp @@ -216,7 +216,7 @@ void make_block_local_offload(OffloadedStmt *offload, // TODO: no more abuse of gather_statements... 
irpass::analysis::gather_statements(offload->body.get(), [&](Stmt *stmt) { if (auto global_ptr = stmt->cast()) { - if (global_ptr->snodes[0] == snode) { + if (global_ptr->snode == snode) { global_ptrs.push_back(global_ptr); } } diff --git a/taichi/transforms/make_mesh_block_local.cpp b/taichi/transforms/make_mesh_block_local.cpp index 5208ca55bb38b..c7f3d689ebadc 100644 --- a/taichi/transforms/make_mesh_block_local.cpp +++ b/taichi/transforms/make_mesh_block_local.cpp @@ -100,7 +100,7 @@ void MakeMeshBlockLocal::replace_global_ptrs(SNode *snode) { std::vector global_ptrs; irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) { if (auto global_ptr = stmt->cast()) { - if (global_ptr->snodes[0] == snode && + if (global_ptr->snode == snode && global_ptr->indices[0]->is()) { global_ptrs.push_back(global_ptr); } diff --git a/taichi/transforms/make_thread_local.cpp b/taichi/transforms/make_thread_local.cpp index c6b371b7ea0ec..d6e1ca852138d 100644 --- a/taichi/transforms/make_thread_local.cpp +++ b/taichi/transforms/make_thread_local.cpp @@ -106,9 +106,8 @@ void make_thread_local_offload(OffloadedStmt *offload) { // We can only optimized reductions to global ptrs with form like // loss[None] (0-D fields) for now. // No TLS on quant types. - return (dest->snodes[0]->type == SNodeType::place) && - dest->indices.empty() && - dest->snodes[0]->dt->is(); + return (dest->snode->type == SNodeType::place) && + dest->indices.empty() && dest->snode->dt->is(); }); auto valid_global_tmps = find_global_reduction_destinations( diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index d0ba541c0661a..78b4868e94747 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -144,20 +144,17 @@ class TypeCheck : public IRVisitor { return; } stmt->ret_type.set_is_pointer(true); - if (stmt->snodes) { + if (stmt->snode) { stmt->ret_type = - TypeFactory::get_instance().get_pointer_type(stmt->snodes[0]->dt); + TypeFactory::get_instance().get_pointer_type(stmt->snode->dt); } else TI_WARN("[{}] Type inference failed: snode is nullptr.\n{}", stmt->name(), stmt->tb); - for (int l = 0; l < stmt->snodes.size(); l++) { - if (stmt->snodes[l]->parent->num_active_indices != 0 && - stmt->snodes[l]->parent->num_active_indices != stmt->indices.size()) { - TI_ERROR("[{}] {} has {} indices. Indexed with {}.", stmt->name(), - stmt->snodes[l]->parent->node_type_name, - stmt->snodes[l]->parent->num_active_indices, - stmt->indices.size()); - } + if (stmt->snode->parent->num_active_indices != 0 && + stmt->snode->parent->num_active_indices != stmt->indices.size()) { + TI_ERROR("[{}] {} has {} indices. Indexed with {}.", stmt->name(), + stmt->snode->parent->node_type_name, + stmt->snode->parent->num_active_indices, stmt->indices.size()); } for (int i = 0; i < stmt->indices.size(); i++) { if (!is_integral(stmt->indices[i]->ret_type)) { @@ -409,8 +406,7 @@ class TypeCheck : public IRVisitor { void visit(ExternalPtrStmt *stmt) override { stmt->ret_type.set_is_pointer(true); - stmt->ret_type = TypeFactory::create_vector_or_scalar_type( - stmt->base_ptrs.size(), stmt->base_ptrs[0]->ret_type); + stmt->ret_type = stmt->base_ptr->ret_type; for (int i = 0; i < stmt->indices.size(); i++) { TI_ASSERT(is_integral(stmt->indices[i]->ret_type)); if (stmt->indices[i]->ret_type != PrimitiveType::i32) {