From b0cd578e98d38453d9e2cf46f689d956f91a5b33 Mon Sep 17 00:00:00 2001
From: Yi Xu <xy_xuyi@foxmail.com>
Date: Mon, 29 Aug 2022 14:11:50 +0800
Subject: [PATCH 1/4] Handle ExternalPtrStmt

---
 taichi/analysis/alias_analysis.cpp            |  6 +-
 .../gather_uniquely_accessed_pointers.cpp     | 78 +++++++++----------
 taichi/codegen/cc/codegen_cc.cpp              |  4 +-
 taichi/codegen/llvm/codegen_llvm.cpp          |  4 +-
 taichi/codegen/metal/codegen_metal.cpp        |  4 +-
 taichi/codegen/spirv/spirv_codegen.cpp        |  2 +-
 taichi/ir/statements.cpp                      | 17 ++--
 taichi/ir/statements.h                        | 15 ++--
 taichi/transforms/demote_atomics.cpp          | 12 ++-
 taichi/transforms/detect_read_only.cpp        |  6 +-
 taichi/transforms/ir_printer.cpp              | 10 +--
 taichi/transforms/type_check.cpp              |  3 +-
 12 files changed, 71 insertions(+), 90 deletions(-)
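
Illustration (not part of the commit): a minimal standalone C++ sketch of the
simplification this patch applies to ExternalPtrStmt, where a one-element
LaneAttribute of base pointers is replaced by a single base_ptr. The struct and
variable names below are hypothetical stand-ins, not the Taichi classes; the
sketch only mirrors the shape of the change.

#include <cassert>
#include <vector>

struct Stmt {};

// Old shape: a container field that in practice always held exactly one pointer.
struct ExternalPtrOld {
  std::vector<Stmt *> base_ptrs;
  explicit ExternalPtrOld(const std::vector<Stmt *> &ptrs) : base_ptrs(ptrs) {
    assert(base_ptrs.size() == 1 && base_ptrs[0] != nullptr);
  }
};

// New shape: store the single base pointer directly.
struct ExternalPtrNew {
  Stmt *base_ptr;
  explicit ExternalPtrNew(Stmt *ptr) : base_ptr(ptr) {
    assert(base_ptr != nullptr);
  }
};

int main() {
  Stmt arg;
  ExternalPtrOld old_style({&arg});  // callers previously indexed base_ptrs[0]
  ExternalPtrNew new_style(&arg);    // callers now use base_ptr directly
  assert(old_style.base_ptrs[0] == new_style.base_ptr);
  return 0;
}

Accordingly, call sites in the diff change from stmt->base_ptrs[0] to
stmt->base_ptr.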

diff --git a/taichi/analysis/alias_analysis.cpp b/taichi/analysis/alias_analysis.cpp
index 3e401e3768e67..828fccad48355 100644
--- a/taichi/analysis/alias_analysis.cpp
+++ b/taichi/analysis/alias_analysis.cpp
@@ -96,9 +96,9 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
       return AliasResult::different;
     auto ptr1 = var1->as<ExternalPtrStmt>();
     auto ptr2 = var2->as<ExternalPtrStmt>();
-    if (ptr1->base_ptrs[0] != ptr2->base_ptrs[0]) {
-      auto base1 = ptr1->base_ptrs[0]->as<ArgLoadStmt>();
-      auto base2 = ptr2->base_ptrs[0]->as<ArgLoadStmt>();
+    if (ptr1->base_ptr != ptr2->base_ptr) {
+      auto base1 = ptr1->base_ptr->as<ArgLoadStmt>();
+      auto base2 = ptr2->base_ptr->as<ArgLoadStmt>();
       if (base1->arg_id != base2->arg_id) {
         return AliasResult::different;
       }
diff --git a/taichi/analysis/gather_uniquely_accessed_pointers.cpp b/taichi/analysis/gather_uniquely_accessed_pointers.cpp
index 1d2a770120d89..2347b2b7a01bd 100644
--- a/taichi/analysis/gather_uniquely_accessed_pointers.cpp
+++ b/taichi/analysis/gather_uniquely_accessed_pointers.cpp
@@ -233,50 +233,48 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
     // A memory location of an ExternalPtrStmt depends on the indices
     // If the accessed indices are loop unique,
     // the accessed memory location is loop unique
-    for (auto base_ptr : stmt->base_ptrs.data) {
-      ArgLoadStmt *arg_load_stmt = base_ptr->as<ArgLoadStmt>();
-      int arg_id = arg_load_stmt->arg_id;
+    ArgLoadStmt *arg_load_stmt = stmt->base_ptr->as<ArgLoadStmt>();
+    int arg_id = arg_load_stmt->arg_id;
 
-      auto accessed_ptr = accessed_arr_pointer_.find(arg_id);
+    auto accessed_ptr = accessed_arr_pointer_.find(arg_id);
 
-      bool stmt_loop_unique =
-          loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt);
+    bool stmt_loop_unique =
+        loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt);
 
-      if (!stmt_loop_unique) {
-        accessed_arr_pointer_[arg_id] = nullptr;  // not loop-unique
+    if (!stmt_loop_unique) {
+      accessed_arr_pointer_[arg_id] = nullptr;  // not loop-unique
+    } else {
+      if (accessed_ptr == accessed_arr_pointer_.end()) {
+        // First time using arr @ arg_id
+        accessed_arr_pointer_[arg_id] = stmt;
       } else {
-        if (accessed_ptr == accessed_arr_pointer_.end()) {
-          // First time using arr @ arg_id
-          accessed_arr_pointer_[arg_id] = stmt;
-        } else {
-          /**
-           * We know stmt->base_ptr and the previously recorded pointers
-           * are loop-unique. We need to figure out whether their loop-unique
-           * indices are the same while ignoring the others.
-           * e.g. a[i, j, 1] and a[i, j, 2] are both uniquely accessed
-           *      a[i, j, 1] and a[j, i, 2] are not uniquely accessed
-           *      a[i, j + 1, 1] and a[i, j, 2] are not uniquely accessed
-           * This is a bit stricter than needed.
-           * e.g. a[i, j, i] and a[i, j, 0] are uniquely accessed
-           * However this is probably not common and improvements can be made
-           * in a future patch.
-           */
-          if (accessed_ptr->second) {
-            ExternalPtrStmt *other_ptr = accessed_ptr->second;
-            TI_ASSERT(stmt->indices.size() == other_ptr->indices.size());
-            for (int axis = 0; axis < stmt->indices.size(); axis++) {
-              Stmt *this_index = stmt->indices[axis];
-              Stmt *other_index = other_ptr->indices[axis];
-              // We only compare unique indices here.
-              // Since both pointers are loop-unique, all the unique indices
-              // need to be the same for both to be uniquely accessed
-              if (loop_unique_stmt_searcher_.is_partially_loop_unique(
-                      this_index)) {
-                if (!irpass::analysis::same_value(this_index, other_index)) {
-                  // Not equal -> not uniquely accessed
-                  accessed_arr_pointer_[arg_id] = nullptr;
-                  break;
-                }
+        /**
+         * We know stmt->base_ptr and the previously recorded pointers
+         * are loop-unique. We need to figure out whether their loop-unique
+         * indices are the same while ignoring the others.
+         * e.g. a[i, j, 1] and a[i, j, 2] are both uniquely accessed
+         *      a[i, j, 1] and a[j, i, 2] are not uniquely accessed
+         *      a[i, j + 1, 1] and a[i, j, 2] are not uniquely accessed
+         * This is a bit stricter than needed.
+         * e.g. a[i, j, i] and a[i, j, 0] are uniquely accessed
+         * However this is probably not common and improvements can be made
+         * in a future patch.
+         */
+        if (accessed_ptr->second) {
+          ExternalPtrStmt *other_ptr = accessed_ptr->second;
+          TI_ASSERT(stmt->indices.size() == other_ptr->indices.size());
+          for (int axis = 0; axis < stmt->indices.size(); axis++) {
+            Stmt *this_index = stmt->indices[axis];
+            Stmt *other_index = other_ptr->indices[axis];
+            // We only compare unique indices here.
+            // Since both pointers are loop-unique, all the unique indices
+            // need to be the same for both to be uniquely accessed
+            if (loop_unique_stmt_searcher_.is_partially_loop_unique(
+                    this_index)) {
+              if (!irpass::analysis::same_value(this_index, other_index)) {
+                // Not equal -> not uniquely accessed
+                accessed_arr_pointer_[arg_id] = nullptr;
+                break;
               }
             }
           }
diff --git a/taichi/codegen/cc/codegen_cc.cpp b/taichi/codegen/cc/codegen_cc.cpp
index 9a18de7742ba7..2bff22b9e6b1d 100644
--- a/taichi/codegen/cc/codegen_cc.cpp
+++ b/taichi/codegen/cc/codegen_cc.cpp
@@ -152,7 +152,7 @@ class CCTransformer : public IRVisitor {
 
   void visit(ExternalPtrStmt *stmt) override {
     std::string offset = "0";
-    const auto *argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
+    const auto *argload = stmt->base_ptr->as<ArgLoadStmt>();
     const int arg_id = argload->arg_id;
     const auto element_shape = stmt->element_shape;
     const auto layout = stmt->element_dim < 0 ? ExternalArrayLayout::kAOS
@@ -177,7 +177,7 @@ class CCTransformer : public IRVisitor {
     auto var =
         define_var(cc_data_type_name(stmt->element_type().ptr_removed()) + " *",
                    stmt->raw_name());
-    emit("{} = {} + {};", var, stmt->base_ptrs[0]->raw_name(), offset);
+    emit("{} = {} + {};", var, stmt->base_ptr->raw_name(), offset);
   }
 
   void visit(ArgLoadStmt *stmt) override {
diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp
index 5101b82c65c5a..0eea94a615bea 100644
--- a/taichi/codegen/llvm/codegen_llvm.cpp
+++ b/taichi/codegen/llvm/codegen_llvm.cpp
@@ -1767,7 +1767,7 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) {
 }
 
 void TaskCodeGenLLVM::visit(ExternalPtrStmt *stmt) {
-  auto argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
+  auto argload = stmt->base_ptr->as<ArgLoadStmt>();
   auto arg_id = argload->arg_id;
   int num_indices = stmt->indices.size();
   std::vector<llvm::Value *> sizes(num_indices);
@@ -1787,7 +1787,7 @@ void TaskCodeGenLLVM::visit(ExternalPtrStmt *stmt) {
 
   auto dt = stmt->ret_type.ptr_removed();
   auto base_ty = tlctx->get_data_type(dt);
-  auto base = builder->CreateBitCast(llvm_val[stmt->base_ptrs[0]],
+  auto base = builder->CreateBitCast(llvm_val[stmt->base_ptr],
                                      llvm::PointerType::get(base_ty, 0));
 
   auto linear_index = tlctx->get_constant(0);
diff --git a/taichi/codegen/metal/codegen_metal.cpp b/taichi/codegen/metal/codegen_metal.cpp
index f420fe3ae3221..64ffd7daa4a7e 100644
--- a/taichi/codegen/metal/codegen_metal.cpp
+++ b/taichi/codegen/metal/codegen_metal.cpp
@@ -458,7 +458,7 @@ class KernelCodegenImpl : public IRVisitor {
     emit("{{");
     {
       ScopedIndent s(current_appender());
-      const auto *argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
+      const auto *argload = stmt->base_ptr->as<ArgLoadStmt>();
       const int arg_id = argload->arg_id;
       const int num_indices = stmt->indices.size();
       const auto &element_shape = stmt->element_shape;
@@ -492,7 +492,7 @@ class KernelCodegenImpl : public IRVisitor {
 
     const auto dt = metal_data_type_name(stmt->element_type());
     emit("device {} *{} = ({} + {});", dt, stmt->raw_name(),
-         stmt->base_ptrs[0]->raw_name(), linear_index_name);
+         stmt->base_ptr->raw_name(), linear_index_name);
   }
 
   void visit(GlobalTemporaryStmt *stmt) override {
diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp
index ae794ff99a6d8..d99bc53810b13 100644
--- a/taichi/codegen/spirv/spirv_codegen.cpp
+++ b/taichi/codegen/spirv/spirv_codegen.cpp
@@ -588,7 +588,7 @@ class TaskCodegen : public IRVisitor {
     // Used mostly for transferring data between host (e.g. numpy array) and
     // device.
     spirv::Value linear_offset = ir_->int_immediate_number(ir_->i32_type(), 0);
-    const auto *argload = stmt->base_ptrs[0]->as<ArgLoadStmt>();
+    const auto *argload = stmt->base_ptr->as<ArgLoadStmt>();
     const int arg_id = argload->arg_id;
     {
       const int num_indices = stmt->indices.size();
diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp
index 80a36ea29c6bf..4cfb4fc808681 100644
--- a/taichi/ir/statements.cpp
+++ b/taichi/ir/statements.cpp
@@ -32,24 +32,19 @@ bool UnaryOpStmt::same_operation(UnaryOpStmt *o) const {
   return false;
 }
 
-ExternalPtrStmt::ExternalPtrStmt(const LaneAttribute<Stmt *> &base_ptrs,
+ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr,
                                  const std::vector<Stmt *> &indices)
-    : base_ptrs(base_ptrs), indices(indices) {
-  DataType dt = PrimitiveType::f32;
-  for (int i = 0; i < (int)base_ptrs.size(); i++) {
-    TI_ASSERT(base_ptrs[i] != nullptr);
-    TI_ASSERT(base_ptrs[i]->is<ArgLoadStmt>());
-  }
-  TI_ASSERT(base_ptrs.size() == 1);
-  element_type() = dt;
+    : base_ptr(base_ptr), indices(indices) {
+  TI_ASSERT(base_ptr != nullptr);
+  TI_ASSERT(base_ptr->is<ArgLoadStmt>());
   TI_STMT_REG_FIELDS;
 }
 
-ExternalPtrStmt::ExternalPtrStmt(const LaneAttribute<Stmt *> &base_ptrs,
+ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr,
                                  const std::vector<Stmt *> &indices,
                                  const std::vector<int> &element_shape,
                                  int element_dim)
-    : ExternalPtrStmt(base_ptrs, indices) {
+    : ExternalPtrStmt(base_ptr, indices) {
   this->element_shape = element_shape;
   this->element_dim = element_dim;
 }
diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h
index 9dd71663cc0a5..3baf2e3ee39e4 100644
--- a/taichi/ir/statements.h
+++ b/taichi/ir/statements.h
@@ -292,31 +292,28 @@ class AtomicOpStmt : public Stmt {
 };
 
 /**
- * An external pointer. |base_ptrs| should be ArgLoadStmts with
+ * An external pointer. |base_ptr| should be an ArgLoadStmt with
  * |is_ptr| == true.
  */
 class ExternalPtrStmt : public Stmt {
  public:
-  LaneAttribute<Stmt *> base_ptrs;
+  Stmt *base_ptr;
   std::vector<Stmt *> indices;
   std::vector<int> element_shape;
   // AOS: element_dim < 0
   // SOA: element_dim > 0
   int element_dim;
 
-  ExternalPtrStmt(const LaneAttribute<Stmt *> &base_ptrs,
-                  const std::vector<Stmt *> &indices);
+  ExternalPtrStmt(Stmt *base_ptr, const std::vector<Stmt *> &indices);
 
-  ExternalPtrStmt(const LaneAttribute<Stmt *> &base_ptrs,
-                  const std::vector<Stmt *> &indices,
-                  const std::vector<int> &element_shape,
-                  int element_dim);
+  ExternalPtrStmt(Stmt *base_ptr, const std::vector<Stmt *> &indices,
+                  const std::vector<int> &element_shape, int element_dim);
 
   bool has_global_side_effect() const override {
     return false;
   }
 
-  TI_STMT_DEF_FIELDS(ret_type, base_ptrs, indices);
+  TI_STMT_DEF_FIELDS(ret_type, base_ptr, indices);
   TI_DEFINE_ACCEPT_AND_CLONE
 };
 
diff --git a/taichi/transforms/demote_atomics.cpp b/taichi/transforms/demote_atomics.cpp
index 93118d9ed9de5..0b3a3552c666f 100644
--- a/taichi/transforms/demote_atomics.cpp
+++ b/taichi/transforms/demote_atomics.cpp
@@ -85,13 +85,11 @@ class DemoteAtomics : public BasicStmtVisitor {
           if (dest_ptr->indices.empty()) {
             demote = false;
           }
-          for (Stmt *base_stmt : dest_ptr->base_ptrs.data) {
-            ArgLoadStmt *arg_load_stmt = base_stmt->as<ArgLoadStmt>();
-            int arg_id = arg_load_stmt->arg_id;
-            if (loop_unique_arr_ptr_[arg_id] == nullptr) {
-              // Not loop unique
-              demote = false;
-            }
+          ArgLoadStmt *arg_load_stmt = dest_ptr->base_ptr->as<ArgLoadStmt>();
+          int arg_id = arg_load_stmt->arg_id;
+          if (loop_unique_arr_ptr_[arg_id] == nullptr) {
+            // Not loop unique
+            demote = false;
           }
           // TODO: Is BLS / Mem Access Opt a thing for any_arr?
         }
diff --git a/taichi/transforms/detect_read_only.cpp b/taichi/transforms/detect_read_only.cpp
index fb5d5573a2795..c0e2825d4b36a 100644
--- a/taichi/transforms/detect_read_only.cpp
+++ b/taichi/transforms/detect_read_only.cpp
@@ -38,7 +38,7 @@ class ExternalPtrAccessVisitor : public BasicStmtVisitor {
       return;
 
     ExternalPtrStmt *src = stmt->src->cast<ExternalPtrStmt>();
-    ArgLoadStmt *arg = src->base_ptrs.data[0]->cast<ArgLoadStmt>();
+    ArgLoadStmt *arg = src->base_ptr->cast<ArgLoadStmt>();
     if (map_.find(arg->arg_id) != map_.end()) {
       map_[arg->arg_id] = map_[arg->arg_id] | ExternalPtrAccess::READ;
     } else {
@@ -51,7 +51,7 @@ class ExternalPtrAccessVisitor : public BasicStmtVisitor {
       return;
 
     ExternalPtrStmt *dst = stmt->dest->cast<ExternalPtrStmt>();
-    ArgLoadStmt *arg = dst->base_ptrs.data[0]->cast<ArgLoadStmt>();
+    ArgLoadStmt *arg = dst->base_ptr->cast<ArgLoadStmt>();
     if (map_.find(arg->arg_id) != map_.end()) {
       map_[arg->arg_id] = map_[arg->arg_id] | ExternalPtrAccess::WRITE;
     } else {
@@ -65,7 +65,7 @@ class ExternalPtrAccessVisitor : public BasicStmtVisitor {
 
     // Atomics modifies existing state (therefore both read & write)
     ExternalPtrStmt *dst = stmt->dest->cast<ExternalPtrStmt>();
-    ArgLoadStmt *arg = dst->base_ptrs.data[0]->cast<ArgLoadStmt>();
+    ArgLoadStmt *arg = dst->base_ptr->cast<ArgLoadStmt>();
     map_[arg->arg_id] = ExternalPtrAccess::WRITE | ExternalPtrAccess::READ;
   }
 };
diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp
index 23b5064832454..471172d0869d0 100644
--- a/taichi/transforms/ir_printer.cpp
+++ b/taichi/transforms/ir_printer.cpp
@@ -528,14 +528,8 @@ class IRPrinter : public IRVisitor {
   }
 
   void visit(ExternalPtrStmt *stmt) override {
-    std::string s = "<";
-    for (int i = 0; i < (int)stmt->base_ptrs.size(); i++) {
-      s += fmt::format("{}", stmt->base_ptrs[i]->name());
-      if (i + 1 < (int)stmt->base_ptrs.size()) {
-        s += ", ";
-      }
-    }
-    s += ">, [";
+    std::string s = stmt->base_ptr->name();
+    s += ", [";
     for (int i = 0; i < (int)stmt->indices.size(); i++) {
       s += fmt::format("{}", stmt->indices[i]->name());
       if (i + 1 < (int)stmt->indices.size()) {
diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp
index d0ba541c0661a..de4ffefc2f6ed 100644
--- a/taichi/transforms/type_check.cpp
+++ b/taichi/transforms/type_check.cpp
@@ -409,8 +409,7 @@ class TypeCheck : public IRVisitor {
 
   void visit(ExternalPtrStmt *stmt) override {
     stmt->ret_type.set_is_pointer(true);
-    stmt->ret_type = TypeFactory::create_vector_or_scalar_type(
-        stmt->base_ptrs.size(), stmt->base_ptrs[0]->ret_type);
+    stmt->ret_type = stmt->base_ptr->ret_type;
     for (int i = 0; i < stmt->indices.size(); i++) {
       TI_ASSERT(is_integral(stmt->indices[i]->ret_type));
       if (stmt->indices[i]->ret_type != PrimitiveType::i32) {

From 77e8dd9775027b270f5491f3aff005cbc0bb2eea Mon Sep 17 00:00:00 2001
From: Yi Xu <xy_xuyi@foxmail.com>
Date: Mon, 29 Aug 2022 16:14:53 +0800
Subject: [PATCH 2/4] Handle GlobalPtrStmt

---
 taichi/analysis/alias_analysis.cpp            |  6 +-
 taichi/analysis/bls_analyzer.cpp              |  2 +-
 taichi/analysis/gather_snode_read_writes.cpp  | 10 +-
 .../gather_uniquely_accessed_pointers.cpp     | 45 ++++-----
 taichi/analysis/mesh_bls_analyzer.cpp         |  2 +-
 taichi/analysis/same_statements.cpp           |  5 +-
 taichi/ir/control_flow_graph.cpp              | 24 ++---
 taichi/ir/statements.cpp                      | 56 +----------
 taichi/ir/statements.h                        | 13 +--
 taichi/transforms/auto_diff.cpp               | 93 +++++++++----------
 taichi/transforms/bit_loop_vectorize.cpp      |  4 +-
 taichi/transforms/check_out_of_bound.cpp      |  3 +-
 taichi/transforms/demote_atomics.cpp          | 29 +++---
 taichi/transforms/flag_access.cpp             | 40 ++++----
 taichi/transforms/ir_printer.cpp              |  4 +-
 taichi/transforms/lower_access.cpp            |  7 +-
 taichi/transforms/make_block_local.cpp        |  2 +-
 taichi/transforms/make_mesh_block_local.cpp   |  2 +-
 taichi/transforms/make_thread_local.cpp       |  4 +-
 taichi/transforms/scalar_pointer_lowerer.cpp  |  2 +-
 taichi/transforms/type_check.cpp              | 18 ++--
 21 files changed, 142 insertions(+), 229 deletions(-)
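
Illustration (not part of the commit): a minimal standalone C++ sketch
(hypothetical Ptr, snode_id, and record_access names, not the Taichi API) of
the per-SNode bookkeeping that UniquelyAccessedSNodeSearcher keeps after the
snodes loop is dropped: the first loop-unique access to an SNode is recorded,
and any later access that is not provably the same address demotes the entry
to nullptr.

#include <cassert>
#include <unordered_map>

struct Ptr {
  int snode_id;
  int addr;
};

// Stand-in for irpass::analysis::definitely_same_address().
bool definitely_same_address(const Ptr *a, const Ptr *b) {
  return a && b && a->addr == b->addr;
}

void record_access(std::unordered_map<int, const Ptr *> &accessed,
                   const Ptr *p, bool loop_unique) {
  auto it = accessed.find(p->snode_id);
  if (it == accessed.end()) {
    accessed[p->snode_id] = loop_unique ? p : nullptr;  // first access
  } else if (!definitely_same_address(it->second, p)) {
    it->second = nullptr;  // conflicting access: not uniquely accessed
  }
}

int main() {
  std::unordered_map<int, const Ptr *> accessed;
  Ptr a{0, 42}, b{0, 42}, c{0, 7};
  record_access(accessed, &a, true);
  record_access(accessed, &b, true);
  assert(accessed[0] == &a);       // same address: still uniquely accessed
  record_access(accessed, &c, true);
  assert(accessed[0] == nullptr);  // different address: demoted
  return 0;
}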

diff --git a/taichi/analysis/alias_analysis.cpp b/taichi/analysis/alias_analysis.cpp
index 828fccad48355..3fba06827426c 100644
--- a/taichi/analysis/alias_analysis.cpp
+++ b/taichi/analysis/alias_analysis.cpp
@@ -120,7 +120,7 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
   // SNode::id.
   auto get_snode_id = [](Stmt *s) {
     if (auto ptr = s->cast<GlobalPtrStmt>()) {
-      return ptr->snodes[0]->id;
+      return ptr->snode->id;
     } else if (auto get_child = s->cast<GetChStmt>()) {
       return get_child->output_snode->id;
     }
@@ -137,8 +137,8 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
   if (var1->is<GlobalPtrStmt>() && var2->is<GlobalPtrStmt>()) {
     auto ptr1 = var1->as<GlobalPtrStmt>();
     auto ptr2 = var2->as<GlobalPtrStmt>();
-    auto snode = ptr1->snodes[0];
-    TI_ASSERT(snode == ptr2->snodes[0]);
+    auto snode = ptr1->snode;
+    TI_ASSERT(snode == ptr2->snode);
     TI_ASSERT(ptr1->indices.size() == ptr2->indices.size());
     bool uncertain = false;
     for (int i = 0; i < (int)ptr1->indices.size(); i++) {
diff --git a/taichi/analysis/bls_analyzer.cpp b/taichi/analysis/bls_analyzer.cpp
index 86eb56733615a..2db07a0e7e0a1 100644
--- a/taichi/analysis/bls_analyzer.cpp
+++ b/taichi/analysis/bls_analyzer.cpp
@@ -36,7 +36,7 @@ void BLSAnalyzer::record_access(Stmt *stmt, AccessFlag flag) {
   if (!stmt->is<GlobalPtrStmt>())
     return;  // local alloca
   auto ptr = stmt->as<GlobalPtrStmt>();
-  auto snode = ptr->snodes[0];
+  auto snode = ptr->snode;
   if (!pads_->has(snode)) {
     return;
   }
diff --git a/taichi/analysis/gather_snode_read_writes.cpp b/taichi/analysis/gather_snode_read_writes.cpp
index 45abb217ad84c..270b6a88d074d 100644
--- a/taichi/analysis/gather_snode_read_writes.cpp
+++ b/taichi/analysis/gather_snode_read_writes.cpp
@@ -28,12 +28,10 @@ gather_snode_read_writes(IRNode *root) {
     }
     if (ptr) {
       if (auto *global_ptr = ptr->cast<GlobalPtrStmt>()) {
-        for (auto &snode : global_ptr->snodes.data) {
-          if (read)
-            accessed.first.emplace(snode);
-          if (write)
-            accessed.second.emplace(snode);
-        }
+        if (read)
+          accessed.first.emplace(global_ptr->snode);
+        if (write)
+          accessed.second.emplace(global_ptr->snode);
       }
     }
     return false;
diff --git a/taichi/analysis/gather_uniquely_accessed_pointers.cpp b/taichi/analysis/gather_uniquely_accessed_pointers.cpp
index 2347b2b7a01bd..4c746e37a3c8f 100644
--- a/taichi/analysis/gather_uniquely_accessed_pointers.cpp
+++ b/taichi/analysis/gather_uniquely_accessed_pointers.cpp
@@ -185,6 +185,7 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
   }
 
   void visit(GlobalPtrStmt *stmt) override {
+    auto snode = stmt->snode;
     // mesh-for loop unique
     if (stmt->indices.size() == 1 &&
         stmt->indices[0]->is<MeshIndexConversionStmt>()) {
@@ -194,37 +195,31 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
         idx = idx->as<MeshIndexConversionStmt>()->idx;
       }
       if (idx->is<LoopIndexStmt>() &&
-          idx->as<LoopIndexStmt>()->is_mesh_index()) {  // from-end access
-        for (auto &snode : stmt->snodes.data) {
-          if (rel_access_pointer_.find(snode) ==
-              rel_access_pointer_.end()) {  // not accessed by neibhours yet
-            accessed_pointer_[snode] = stmt;
-          } else {  // accessed by neibhours, so it's not unique
-            accessed_pointer_[snode] = nullptr;
-          }
+        idx->as<LoopIndexStmt>()->is_mesh_index()) {  // from-end access
+        if (rel_access_pointer_.find(snode) ==
+            rel_access_pointer_.end()) {  // not accessed by neighbours yet
+          accessed_pointer_[snode] = stmt;
+        } else {  // accessed by neighbours, so it's not unique
+          accessed_pointer_[snode] = nullptr;
         }
       } else {  // to-end access
-        for (auto &snode : stmt->snodes.data) {
-          rel_access_pointer_[snode] = stmt;
-          accessed_pointer_[snode] =
-              nullptr;  // from-end access should not be unique
-        }
+        rel_access_pointer_[snode] = stmt;
+        accessed_pointer_[snode] =
+            nullptr;  // from-end access should not be unique
       }
     }
     // Range-for / struct-for
-    for (auto &snode : stmt->snodes.data) {
-      auto accessed_ptr = accessed_pointer_.find(snode);
-      if (accessed_ptr == accessed_pointer_.end()) {
-        if (loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt)) {
-          accessed_pointer_[snode] = stmt;
-        } else {
-          accessed_pointer_[snode] = nullptr;  // not loop-unique
-        }
+    auto accessed_ptr = accessed_pointer_.find(snode);
+    if (accessed_ptr == accessed_pointer_.end()) {
+      if (loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt)) {
+        accessed_pointer_[snode] = stmt;
       } else {
-        if (!irpass::analysis::definitely_same_address(accessed_ptr->second,
-                                                       stmt)) {
-          accessed_ptr->second = nullptr;  // not uniquely accessed
-        }
+        accessed_pointer_[snode] = nullptr;  // not loop-unique
+      }
+    } else {
+      if (!irpass::analysis::definitely_same_address(accessed_ptr->second,
+                                                     stmt)) {
+        accessed_ptr->second = nullptr;  // not uniquely accessed
       }
     }
   }
diff --git a/taichi/analysis/mesh_bls_analyzer.cpp b/taichi/analysis/mesh_bls_analyzer.cpp
index c64bedf310148..91ebc35b48da1 100644
--- a/taichi/analysis/mesh_bls_analyzer.cpp
+++ b/taichi/analysis/mesh_bls_analyzer.cpp
@@ -35,7 +35,7 @@ void MeshBLSAnalyzer::record_access(Stmt *stmt, AccessFlag flag) {
   auto idx = conv->idx;
   if (conv_type == mesh::ConvType::g2r)
     return;
-  auto snode = ptr->snodes[0];
+  auto snode = ptr->snode;
   if (!caches_->has(snode)) {
     if (auto_mesh_local_ &&
         (flag == AccessFlag::accumulate ||
diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp
index c1e132cb4d15d..9022e4cb9fd46 100644
--- a/taichi/analysis/same_statements.cpp
+++ b/taichi/analysis/same_statements.cpp
@@ -136,9 +136,8 @@ class IRNodeComparator : public IRVisitor {
         // And we cannot use irpass::analysis::definitely_same_address()
         // directly because that function does not support id_map.
 
-        // TODO: Update this part if GlobalPtrStmt comes to have more fields
-        if (stmt->as<GlobalPtrStmt>()->snodes[0]->id !=
-            other->as<GlobalPtrStmt>()->snodes[0]->id) {
+        if (stmt->as<GlobalPtrStmt>()->snode->id !=
+            other->as<GlobalPtrStmt>()->snode->id) {
           same = false;
           return;
         }
diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp
index 282c22f7092ea..371d4730a11f2 100644
--- a/taichi/ir/control_flow_graph.cpp
+++ b/taichi/ir/control_flow_graph.cpp
@@ -335,22 +335,14 @@ void CFGNode::gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const {
       if (auto global_ptr = load_ptr->cast<GlobalPtrStmt>()) {
         // Avoid computing the UD-chain if every SNode in this global ptr
         // are already loaded because it can be time-consuming.
-        bool already_loaded = true;
-        for (auto &snode : global_ptr->snodes.data) {
-          if (snodes.count(snode) == 0) {
-            already_loaded = false;
-            break;
-          }
-        }
-        if (already_loaded) {
+        auto snode = global_ptr->snode;
+        if (snodes.count(snode) > 0) {
           continue;
         }
         if (reach_in.find(global_ptr) != reach_in.end() &&
             !contain_variable(killed_in_this_node, global_ptr)) {
           // The UD-chain contains the value before this offloaded task.
-          for (auto &snode : global_ptr->snodes.data) {
-            snodes.insert(snode);
-          }
+          snodes.insert(snode);
         }
       }
     }
@@ -459,8 +451,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
           } else if (!is_parallel_executed ||
                      (atomic->dest->is<GlobalPtrStmt>() &&
                       atomic->dest->as<GlobalPtrStmt>()
-                          ->snodes[0]
-                          ->is_scalar())) {
+                          ->snode->is_scalar())) {
             // If this node is parallel executed, we can't weaken a global
             // atomic operation to a global load.
             // TODO: we can weaken it if it's element-wise (i.e. never
@@ -704,9 +695,8 @@ void ControlFlowGraph::live_variable_analysis(
     }
     if (auto *gptr = stmt->cast<GlobalPtrStmt>();
         gptr && config_opt.has_value()) {
-      TI_ASSERT(gptr->snodes.size() == 1);
       const bool res =
-          (config_opt->eliminable_snodes.count(gptr->snodes[0]) == 0);
+          (config_opt->eliminable_snodes.count(gptr->snode) == 0);
       return res;
     }
     // A global pointer that may be loaded after this kernel.
@@ -874,9 +864,7 @@ std::unordered_set<SNode *> ControlFlowGraph::gather_loaded_snodes() {
   // Therefore we include the nodes[final_node]->reach_in in snodes.
   for (auto &stmt : nodes[final_node]->reach_in) {
     if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
-      for (auto &snode : global_ptr->snodes.data) {
-        snodes.insert(snode);
-      }
+      snodes.insert(global_ptr->snode);
     }
   }
 
diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp
index 4cfb4fc808681..244a6816ac1f6 100644
--- a/taichi/ir/statements.cpp
+++ b/taichi/ir/statements.cpp
@@ -49,56 +49,16 @@ ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr,
   this->element_dim = element_dim;
 }
 
-GlobalPtrStmt::GlobalPtrStmt(const LaneAttribute<SNode *> &snodes,
+GlobalPtrStmt::GlobalPtrStmt(SNode *snode,
                              const std::vector<Stmt *> &indices,
                              bool activate)
-    : snodes(snodes),
-      indices(indices),
-      activate(activate),
+    : snode(snode), indices(indices), activate(activate),
       is_bit_vectorized(false) {
-  for (int i = 0; i < (int)snodes.size(); i++) {
-    TI_ASSERT(snodes[i] != nullptr);
-    TI_ASSERT(snodes[0]->dt == snodes[i]->dt);
-  }
-  TI_ASSERT(snodes.size() == 1);
-  element_type() = snodes[0]->dt;
+  TI_ASSERT(snode != nullptr);
+  element_type() = snode->dt;
   TI_STMT_REG_FIELDS;
 }
 
-bool GlobalPtrStmt::is_element_wise(const SNode *snode) const {
-  if (snode == nullptr) {
-    // check every SNode when "snode" is nullptr
-    for (const auto &snode_i : snodes.data) {
-      if (!is_element_wise(snode_i)) {
-        return false;
-      }
-    }
-    return true;
-  }
-  // check if this statement is element-wise on a specific SNode, i.e., argument
-  // "snode"
-  for (int i = 0; i < (int)indices.size(); i++) {
-    if (auto loop_index_i = indices[i]->cast<LoopIndexStmt>();
-        !(loop_index_i && loop_index_i->loop->is<OffloadedStmt>() &&
-          loop_index_i->index == snode->physical_index_position[i])) {
-      return false;
-    }
-  }
-  return true;
-}
-
-bool GlobalPtrStmt::covers_snode(const SNode *snode) const {
-  // Check if the addresses of this statement all over the loop cover
-  // all active indices of the snode.
-  for (auto &index : indices) {
-    if (auto loop_unique = index->cast<LoopUniqueStmt>()) {
-      if (loop_unique->covers_snode(snode))
-        return true;
-    }
-  }
-  return is_element_wise(snode);
-}
-
 PtrOffsetStmt::PtrOffsetStmt(Stmt *origin_input, Stmt *offset_input) {
   origin = origin_input;
   offset = offset_input;
@@ -163,14 +123,6 @@ LoopUniqueStmt::LoopUniqueStmt(Stmt *input, const std::vector<SNode *> &covers)
   TI_STMT_REG_FIELDS;
 }
 
-bool LoopUniqueStmt::covers_snode(const SNode *snode) const {
-  if (snode->is_place()) {
-    return covers.count(snode->parent->id) > 0;
-  } else {
-    TI_NOT_IMPLEMENTED
-  }
-}
-
 Stmt *LocalLoadStmt::previous_store_or_alloca_in_block() {
   int position = parent->locate(this);
   // TI_ASSERT(width() == 1);
diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h
index 3baf2e3ee39e4..0416a5168aba5 100644
--- a/taichi/ir/statements.h
+++ b/taichi/ir/statements.h
@@ -327,19 +327,14 @@ class ExternalPtrStmt : public Stmt {
  */
 class GlobalPtrStmt : public Stmt {
  public:
-  LaneAttribute<SNode *> snodes;
+  SNode *snode;
   std::vector<Stmt *> indices;
   bool activate;
   bool is_bit_vectorized;  // for bit_loop_vectorize pass
 
-  GlobalPtrStmt(const LaneAttribute<SNode *> &snodes,
-                const std::vector<Stmt *> &indices,
+  GlobalPtrStmt(SNode *snode, const std::vector<Stmt *> &indices,
                 bool activate = true);
 
-  bool is_element_wise(const SNode *snode) const;
-
-  bool covers_snode(const SNode *snode) const;
-
   bool has_global_side_effect() const override {
     return activate;
   }
@@ -348,7 +343,7 @@ class GlobalPtrStmt : public Stmt {
     return true;
   }
 
-  TI_STMT_DEF_FIELDS(ret_type, snodes, indices, activate, is_bit_vectorized);
+  TI_STMT_DEF_FIELDS(ret_type, snode, indices, activate, is_bit_vectorized);
   TI_DEFINE_ACCEPT_AND_CLONE
 };
 
@@ -537,8 +532,6 @@ class LoopUniqueStmt : public Stmt {
 
   LoopUniqueStmt(Stmt *input, const std::vector<SNode *> &covers);
 
-  bool covers_snode(const SNode *snode) const;
-
   bool has_global_side_effect() const override {
     return false;
   }
diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp
index 7d12db7d83c5f..2a91cc48dc41f 100644
--- a/taichi/transforms/auto_diff.cpp
+++ b/taichi/transforms/auto_diff.cpp
@@ -35,11 +35,8 @@ class IndependentBlocksJudger : public BasicStmtVisitor {
     if (is_inside_loop_)
       return;
     TI_ASSERT(stmt->dest->is<GlobalPtrStmt>());
-    for (const auto &node : stmt->dest->cast<GlobalPtrStmt>()->snodes.data) {
-      if (node->has_adjoint()) {
-        qualified_atomics_ = false;
-        break;
-      }
+    if (stmt->dest->as<GlobalPtrStmt>()->snode->has_adjoint()) {
+      qualified_atomics_ = false;
     }
   }
 
@@ -1001,18 +998,18 @@ class MakeAdjoint : public ADTransform {
       src = stmt->src->as<GlobalPtrStmt>();
     }
 
-    auto snodes = src->snodes;
-    if (!snodes[0]->has_adjoint()) {
+    auto snode = src->snode;
+    if (!snode->has_adjoint()) {
       // No adjoint SNode. Do nothing
       return;
     }
-    if (gradients_stopped(stmt, snodes[0])) {
+    if (gradients_stopped(stmt, snode)) {
       // gradients stopped, do nothing.
       return;
     }
-    TI_ASSERT(snodes[0]->get_adjoint() != nullptr);
-    snodes[0] = snodes[0]->get_adjoint();
-    auto adj_ptr = insert<GlobalPtrStmt>(snodes, src->indices);
+    TI_ASSERT(snode->get_adjoint() != nullptr);
+    snode = snode->get_adjoint();
+    auto adj_ptr = insert<GlobalPtrStmt>(snode, src->indices);
     if (is_ptr_offset) {
       adj_ptr = insert<PtrOffsetStmt>(adj_ptr,
                                       stmt->src->as<PtrOffsetStmt>()->offset);
@@ -1037,14 +1034,14 @@ class MakeAdjoint : public ADTransform {
       dest = stmt->dest->as<GlobalPtrStmt>();
     }
 
-    auto snodes = dest->snodes;
-    if (!snodes[0]->has_adjoint()) {
+    auto snode = dest->snode;
+    if (!snode->has_adjoint()) {
       // no gradient (likely integer types)
       return;
     }
-    TI_ASSERT(snodes[0]->get_adjoint() != nullptr);
-    snodes[0] = snodes[0]->get_adjoint();
-    auto adjoint_ptr = insert<GlobalPtrStmt>(snodes, dest->indices);
+    TI_ASSERT(snode->get_adjoint() != nullptr);
+    snode = snode->get_adjoint();
+    auto adjoint_ptr = insert<GlobalPtrStmt>(snode, dest->indices);
     if (is_ptr_offset) {
       adjoint_ptr = insert<PtrOffsetStmt>(
           adjoint_ptr, stmt->dest->as<PtrOffsetStmt>()->offset);
@@ -1064,15 +1061,15 @@ class MakeAdjoint : public ADTransform {
       dest = stmt->dest->as<GlobalPtrStmt>();
     }
 
-    auto snodes = dest->snodes;
-    if (!snodes[0]->has_adjoint()) {
+    auto snode = dest->snode;
+    if (!snode->has_adjoint()) {
       // no gradient (likely integer types)
       return;
     }
 
-    TI_ASSERT(snodes[0]->get_adjoint() != nullptr);
-    snodes[0] = snodes[0]->get_adjoint();
-    auto adjoint_ptr = insert<GlobalPtrStmt>(snodes, dest->indices);
+    TI_ASSERT(snode->get_adjoint() != nullptr);
+    snode = snode->get_adjoint();
+    auto adjoint_ptr = insert<GlobalPtrStmt>(snode, dest->indices);
     if (is_ptr_offset) {
       adjoint_ptr = insert<PtrOffsetStmt>(
           adjoint_ptr, stmt->dest->as<PtrOffsetStmt>()->offset);
@@ -1327,18 +1324,18 @@ class MakeDual : public ADTransform {
     } else {
       src = stmt->src->as<GlobalPtrStmt>();
     }
-    auto snodes = src->snodes;
-    if (!snodes[0]->has_dual()) {
+    auto snode = src->snode;
+    if (!snode->has_dual()) {
       // No dual SNode. Do nothing
       return;
     }
-    if (gradients_stopped(stmt, snodes[0])) {
+    if (gradients_stopped(stmt, snode)) {
       // gradients stopped, do nothing.
       return;
     }
-    TI_ASSERT(snodes[0]->get_dual() != nullptr);
-    snodes[0] = snodes[0]->get_dual();
-    auto dual_ptr = insert<GlobalPtrStmt>(snodes, src->indices);
+    TI_ASSERT(snode->get_dual() != nullptr);
+    snode = snode->get_dual();
+    auto dual_ptr = insert<GlobalPtrStmt>(snode, src->indices);
     if (is_ptr_offset) {
       dual_ptr = insert<PtrOffsetStmt>(dual_ptr,
                                        stmt->src->as<PtrOffsetStmt>()->offset);
@@ -1355,14 +1352,14 @@ class MakeDual : public ADTransform {
     } else {
       dest = stmt->dest->as<GlobalPtrStmt>();
     }
-    auto snodes = dest->snodes;
-    if (!snodes[0]->has_dual()) {
+    auto snode = dest->snode;
+    if (!snode->has_dual()) {
       // no gradient (likely integer types)
       return;
     }
-    TI_ASSERT(snodes[0]->get_dual() != nullptr);
-    snodes[0] = snodes[0]->get_dual();
-    auto dual_ptr = insert<GlobalPtrStmt>(snodes, dest->indices);
+    TI_ASSERT(snode->get_dual() != nullptr);
+    snode = snode->get_dual();
+    auto dual_ptr = insert<GlobalPtrStmt>(snode, dest->indices);
     if (is_ptr_offset) {
       dual_ptr = insert<PtrOffsetStmt>(dual_ptr,
                                        stmt->dest->as<PtrOffsetStmt>()->offset);
@@ -1379,14 +1376,14 @@ class MakeDual : public ADTransform {
     } else {
       dest = stmt->dest->as<GlobalPtrStmt>();
     }
-    auto snodes = dest->snodes;
-    if (!snodes[0]->has_dual()) {
+    auto snode = dest->snode;
+    if (!snode->has_dual()) {
       // no gradient (likely integer types)
       return;
     }
-    TI_ASSERT(snodes[0]->get_dual() != nullptr);
-    snodes[0] = snodes[0]->get_dual();
-    auto dual_ptr = insert<GlobalPtrStmt>(snodes, dest->indices);
+    TI_ASSERT(snode->get_dual() != nullptr);
+    snode = snode->get_dual();
+    auto dual_ptr = insert<GlobalPtrStmt>(snode, dest->indices);
     if (is_ptr_offset) {
       dual_ptr = insert<PtrOffsetStmt>(dual_ptr,
                                        stmt->dest->as<PtrOffsetStmt>()->offset);
@@ -1550,28 +1547,28 @@ class GloablDataAccessRuleChecker : public BasicStmtVisitor {
 
   void visit(GlobalLoadStmt *stmt) override {
     GlobalPtrStmt *src = stmt->src->as<GlobalPtrStmt>();
-    auto snodes = src->snodes;
-    if (!snodes[0]->has_adjoint_checkbit()) {
+    auto snode = src->snode;
+    if (!snode->has_adjoint_checkbit()) {
       return;
     }
-    TI_ASSERT(snodes[0]->get_adjoint_checkbit() != nullptr);
-    snodes[0] = snodes[0]->get_adjoint_checkbit();
+    TI_ASSERT(snode->get_adjoint_checkbit() != nullptr);
+    snode = snode->get_adjoint_checkbit();
     auto gloabl_ptr =
-        stmt->insert_after_me(Stmt::make<GlobalPtrStmt>(snodes, src->indices));
+        stmt->insert_after_me(Stmt::make<GlobalPtrStmt>(snode, src->indices));
     auto one = gloabl_ptr->insert_after_me(
         Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(1)));
     one->insert_after_me(Stmt::make<GlobalStoreStmt>(gloabl_ptr, one));
   }
 
   void visit_gloabl_store_stmt_and_atomic_add(Stmt *stmt, GlobalPtrStmt *dest) {
-    auto snodes = dest->snodes;
-    if (!snodes[0]->has_adjoint_checkbit()) {
+    auto snode = dest->snode;
+    if (!snode->has_adjoint_checkbit()) {
       return;
     }
-    TI_ASSERT(snodes[0]->get_adjoint_checkbit() != nullptr);
-    snodes[0] = snodes[0]->get_adjoint_checkbit();
+    TI_ASSERT(snode->get_adjoint_checkbit() != nullptr);
+    snode = snode->get_adjoint_checkbit();
     auto global_ptr = stmt->insert_before_me(
-        Stmt::make<GlobalPtrStmt>(snodes, dest->indices));
+        Stmt::make<GlobalPtrStmt>(snode, dest->indices));
     auto global_load =
         stmt->insert_before_me(Stmt::make<GlobalLoadStmt>(global_ptr));
     auto zero = stmt->insert_before_me(
@@ -1581,7 +1578,7 @@ class GloablDataAccessRuleChecker : public BasicStmtVisitor {
     std::string msg = fmt::format(
         "(kernel={}) Breaks the global data access rule. Snode {} is "
         "overwritten unexpectedly.",
-        kernel_name_, dest->snodes[0]->get_node_type_name());
+        kernel_name_, dest->snode->get_node_type_name());
     msg += "\n" + stmt->tb;
 
     stmt->insert_before_me(
diff --git a/taichi/transforms/bit_loop_vectorize.cpp b/taichi/transforms/bit_loop_vectorize.cpp
index 51a88fe4fdce7..d2448919d6d4a 100644
--- a/taichi/transforms/bit_loop_vectorize.cpp
+++ b/taichi/transforms/bit_loop_vectorize.cpp
@@ -61,7 +61,7 @@ class BitLoopVectorize : public IRVisitor {
             auto indices = ptr->indices;
             indices[1] = loop_stmt->body->statements[1].get();
             auto base_ptr =
-                std::make_unique<GlobalPtrStmt>(ptr->snodes, indices);
+                std::make_unique<GlobalPtrStmt>(ptr->snode, indices);
             base_ptr->ret_type = new_ret_type;
             base_ptr->is_bit_vectorized = true;
             // load x[i, j](base)
@@ -80,7 +80,7 @@ class BitLoopVectorize : public IRVisitor {
                 offset_index_opcode, indices[1], offset_constant.get());
             indices[1] = offset_index.get();
             auto offset_ptr =
-                std::make_unique<GlobalPtrStmt>(ptr->snodes, indices);
+                std::make_unique<GlobalPtrStmt>(ptr->snode, indices);
             offset_ptr->ret_type = new_ret_type;
             offset_ptr->is_bit_vectorized = true;
             auto load_offsetted =
diff --git a/taichi/transforms/check_out_of_bound.cpp b/taichi/transforms/check_out_of_bound.cpp
index 47e2105e954bb..32734bf55cb10 100644
--- a/taichi/transforms/check_out_of_bound.cpp
+++ b/taichi/transforms/check_out_of_bound.cpp
@@ -43,8 +43,7 @@ class CheckOutOfBound : public BasicStmtVisitor {
   void visit(GlobalPtrStmt *stmt) override {
     if (is_done(stmt))
       return;
-    TI_ASSERT(stmt->snodes.size() == 1);
-    auto snode = stmt->snodes[0];
+    auto snode = stmt->snode;
     bool has_offset = !(snode->index_offsets.empty());
     auto new_stmts = VecStatement();
     auto zero = new_stmts.push_back<ConstStmt>(LaneAttribute<TypedConstant>(0));
diff --git a/taichi/transforms/demote_atomics.cpp b/taichi/transforms/demote_atomics.cpp
index 0b3a3552c666f..7b6cf210690b8 100644
--- a/taichi/transforms/demote_atomics.cpp
+++ b/taichi/transforms/demote_atomics.cpp
@@ -46,21 +46,18 @@ class DemoteAtomics : public BasicStmtVisitor {
         if (stmt->dest->is<GlobalPtrStmt>()) {
           demote = true;
           auto dest = stmt->dest->as<GlobalPtrStmt>();
-          for (auto snode : dest->snodes.data) {
-            if (loop_unique_ptr_[snode] == nullptr ||
-                loop_unique_ptr_[snode]->indices.empty()) {
-              // not uniquely accessed
-              demote = false;
-              break;
-            }
-            if (current_offloaded->mem_access_opt.has_flag(
-                    snode, SNodeAccessFlag::block_local) ||
-                current_offloaded->mem_access_opt.has_flag(
-                    snode, SNodeAccessFlag::mesh_local)) {
-              // BLS does not support write access yet so we keep atomic_adds.
-              demote = false;
-              break;
-            }
+          auto snode = dest->snode;
+          if (loop_unique_ptr_[snode] == nullptr ||
+              loop_unique_ptr_[snode]->indices.empty()) {
+            // not uniquely accessed
+            demote = false;
+          }
+          if (current_offloaded->mem_access_opt.has_flag(
+                  snode, SNodeAccessFlag::block_local) ||
+              current_offloaded->mem_access_opt.has_flag(
+                  snode, SNodeAccessFlag::mesh_local)) {
+            // BLS does not support write access yet so we keep atomic_adds.
+            demote = false;
           }
           // demote from-end atomics
           if (current_offloaded->task_type == OffloadedTaskType::mesh_for) {
@@ -74,7 +71,7 @@ class DemoteAtomics : public BasicStmtVisitor {
               if (idx->is<LoopIndexStmt>() &&
                   idx->as<LoopIndexStmt>()->is_mesh_index() &&
                   loop_unique_ptr_[stmt->dest->as<GlobalPtrStmt>()
-                                       ->snodes.data[0]] != nullptr) {
+                                       ->snode] != nullptr) {
                 demote = true;
               }
             }
diff --git a/taichi/transforms/flag_access.cpp b/taichi/transforms/flag_access.cpp
index f27e4de7cf961..bf754b518dc1e 100644
--- a/taichi/transforms/flag_access.cpp
+++ b/taichi/transforms/flag_access.cpp
@@ -133,30 +133,28 @@ class WeakenAccess : public BasicStmtVisitor {
           current_struct_for_;
       if (is_struct_for) {
         bool same_as_loop_snode = true;
-        for (auto snode : stmt->snodes.data) {
-          SNode *loop_snode = nullptr;
-          if (current_struct_for_) {
-            loop_snode = current_struct_for_->snode;
-          } else {
-            loop_snode = current_offload_->snode;
-          }
-          TI_ASSERT(loop_snode);
-          if (!share_sparsity(snode, loop_snode)) {
-            same_as_loop_snode = false;
-          }
-          if (stmt->indices.size() == loop_snode->num_active_indices)
-            for (int i = 0; i < loop_snode->num_active_indices; i++) {
-              auto ind = stmt->indices[i];
-              // TODO: vectorized cases?
-              if (auto loop_var = ind->cast<LoopIndexStmt>()) {
-                if (loop_var->index != i) {
-                  same_as_loop_snode = false;
-                }
-              } else {
+        SNode *loop_snode = nullptr;
+        if (current_struct_for_) {
+          loop_snode = current_struct_for_->snode;
+        } else {
+          loop_snode = current_offload_->snode;
+        }
+        TI_ASSERT(loop_snode);
+        if (!share_sparsity(stmt->snode, loop_snode)) {
+          same_as_loop_snode = false;
+        }
+        if (stmt->indices.size() == loop_snode->num_active_indices)
+          for (int i = 0; i < loop_snode->num_active_indices; i++) {
+            auto ind = stmt->indices[i];
+            // TODO: vectorized cases?
+            if (auto loop_var = ind->cast<LoopIndexStmt>()) {
+              if (loop_var->index != i) {
                 same_as_loop_snode = false;
               }
+            } else {
+              same_as_loop_snode = false;
             }
-        }
+          }
         if (same_as_loop_snode)
           stmt->activate = false;
       }
diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp
index 471172d0869d0..ffa5decd26e34 100644
--- a/taichi/transforms/ir_printer.cpp
+++ b/taichi/transforms/ir_printer.cpp
@@ -392,8 +392,8 @@ class IRPrinter : public IRVisitor {
         fmt::format("{}{} = global ptr [", stmt->type_hint(), stmt->name());
 
     std::string snode_name;
-    if (stmt->snodes[0]) {
-      snode_name = stmt->snodes[0]->get_node_type_name_hinted();
+    if (stmt->snode) {
+      snode_name = stmt->snode->get_node_type_name_hinted();
     } else {
       snode_name = "unknown";
     }
diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp
index c284cf440768b..6296574135c63 100644
--- a/taichi/transforms/lower_access.cpp
+++ b/taichi/transforms/lower_access.cpp
@@ -99,8 +99,7 @@ class LowerAccess : public IRVisitor {
       // For ti.is_active
       TI_ASSERT(!activate);
     }
-    PtrLowererImpl lowerer{ptr->snodes[0],         ptr->indices, snode_op,
-                           ptr->is_bit_vectorized, &lowered,     packed};
+    PtrLowererImpl lowerer{ptr->snode, ptr->indices, snode_op, ptr->is_bit_vectorized, &lowered, packed};
     lowerer.set_pointer_needs_activation(activate);
     lowerer.set_lower_access(this);
     lowerer.run();
@@ -109,12 +108,12 @@ class LowerAccess : public IRVisitor {
     if (ptr->is_bit_vectorized) {
       // if the global ptr is bit vectorized, we start from the place snode
       // and find the parent quant array snode, use its physical type
-      auto parent_ret_type = ptr->snodes[0]->parent->physical_type;
+      auto parent_ret_type = ptr->snode->parent->physical_type;
       auto ptr_ret_type =
           TypeFactory::get_instance().get_pointer_type(parent_ret_type);
       lowered_ptr->ret_type = DataType(ptr_ret_type);
     } else {
-      lowered_ptr->ret_type = ptr->snodes[0]->dt;
+      lowered_ptr->ret_type = ptr->snode->dt;
     }
     return lowered;
   }
diff --git a/taichi/transforms/make_block_local.cpp b/taichi/transforms/make_block_local.cpp
index 18aa2b2e7c7eb..eb00a89afd712 100644
--- a/taichi/transforms/make_block_local.cpp
+++ b/taichi/transforms/make_block_local.cpp
@@ -216,7 +216,7 @@ void make_block_local_offload(OffloadedStmt *offload,
       // TODO: no more abuse of gather_statements...
       irpass::analysis::gather_statements(offload->body.get(), [&](Stmt *stmt) {
         if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
-          if (global_ptr->snodes[0] == snode) {
+          if (global_ptr->snode == snode) {
             global_ptrs.push_back(global_ptr);
           }
         }
diff --git a/taichi/transforms/make_mesh_block_local.cpp b/taichi/transforms/make_mesh_block_local.cpp
index 5208ca55bb38b..c7f3d689ebadc 100644
--- a/taichi/transforms/make_mesh_block_local.cpp
+++ b/taichi/transforms/make_mesh_block_local.cpp
@@ -100,7 +100,7 @@ void MakeMeshBlockLocal::replace_global_ptrs(SNode *snode) {
   std::vector<GlobalPtrStmt *> global_ptrs;
   irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) {
     if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
-      if (global_ptr->snodes[0] == snode &&
+      if (global_ptr->snode == snode &&
           global_ptr->indices[0]->is<MeshIndexConversionStmt>()) {
         global_ptrs.push_back(global_ptr);
       }
diff --git a/taichi/transforms/make_thread_local.cpp b/taichi/transforms/make_thread_local.cpp
index c6b371b7ea0ec..6c157456f6b46 100644
--- a/taichi/transforms/make_thread_local.cpp
+++ b/taichi/transforms/make_thread_local.cpp
@@ -106,9 +106,9 @@ void make_thread_local_offload(OffloadedStmt *offload) {
           // We can only optimized reductions to global ptrs with form like
           // loss[None] (0-D fields) for now.
           // No TLS on quant types.
-          return (dest->snodes[0]->type == SNodeType::place) &&
+          return (dest->snode->type == SNodeType::place) &&
                  dest->indices.empty() &&
-                 dest->snodes[0]->dt->is<PrimitiveType>();
+                 dest->snode->dt->is<PrimitiveType>();
         });
     auto valid_global_tmps =
         find_global_reduction_destinations<GlobalTemporaryStmt>(
diff --git a/taichi/transforms/scalar_pointer_lowerer.cpp b/taichi/transforms/scalar_pointer_lowerer.cpp
index 495506cb39cd6..b392bbbe565c3 100644
--- a/taichi/transforms/scalar_pointer_lowerer.cpp
+++ b/taichi/transforms/scalar_pointer_lowerer.cpp
@@ -37,7 +37,7 @@ void ScalarPointerLowerer::run() {
   // for a given SNode. It characterizes the relationship between a parent
   // and a child SNode: "parent.start = child.start + child.num_bits".
   //
-  // For example, if there are two 1D snodes a and b,
+  // For example, if there are two 1D snode a and b,
   // where a = ti.root.dense(ti.i, 2) and b = a.dense(ti.i, 8),
   // we have a.start = b.start + 3 for the i-th dimension.
   // When accessing b[15], then bits [0, 3) of 15 are for accessing b,
diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp
index de4ffefc2f6ed..d6a024471adb8 100644
--- a/taichi/transforms/type_check.cpp
+++ b/taichi/transforms/type_check.cpp
@@ -144,20 +144,18 @@ class TypeCheck : public IRVisitor {
       return;
     }
     stmt->ret_type.set_is_pointer(true);
-    if (stmt->snodes) {
+    if (stmt->snode) {
       stmt->ret_type =
-          TypeFactory::get_instance().get_pointer_type(stmt->snodes[0]->dt);
+          TypeFactory::get_instance().get_pointer_type(stmt->snode->dt);
     } else
       TI_WARN("[{}] Type inference failed: snode is nullptr.\n{}", stmt->name(),
               stmt->tb);
-    for (int l = 0; l < stmt->snodes.size(); l++) {
-      if (stmt->snodes[l]->parent->num_active_indices != 0 &&
-          stmt->snodes[l]->parent->num_active_indices != stmt->indices.size()) {
-        TI_ERROR("[{}] {} has {} indices. Indexed with {}.", stmt->name(),
-                 stmt->snodes[l]->parent->node_type_name,
-                 stmt->snodes[l]->parent->num_active_indices,
-                 stmt->indices.size());
-      }
+    if (stmt->snode->parent->num_active_indices != 0 &&
+        stmt->snode->parent->num_active_indices != stmt->indices.size()) {
+      TI_ERROR("[{}] {} has {} indices. Indexed with {}.", stmt->name(),
+               stmt->snode->parent->node_type_name,
+               stmt->snode->parent->num_active_indices,
+               stmt->indices.size());
     }
     for (int i = 0; i < stmt->indices.size(); i++) {
       if (!is_integral(stmt->indices[i]->ret_type)) {

From 03c56118fe10cc2ebc22847901359ce6afdcf94d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 29 Aug 2022 08:18:24 +0000
Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 taichi/analysis/gather_uniquely_accessed_pointers.cpp | 2 +-
 taichi/ir/control_flow_graph.cpp                      | 6 ++----
 taichi/ir/statements.cpp                              | 4 +++-
 taichi/ir/statements.h                                | 9 ++++++---
 taichi/transforms/auto_diff.cpp                       | 4 ++--
 taichi/transforms/demote_atomics.cpp                  | 4 ++--
 taichi/transforms/lower_access.cpp                    | 4 +++-
 taichi/transforms/make_thread_local.cpp               | 3 +--
 taichi/transforms/type_check.cpp                      | 3 +--
 9 files changed, 21 insertions(+), 18 deletions(-)

diff --git a/taichi/analysis/gather_uniquely_accessed_pointers.cpp b/taichi/analysis/gather_uniquely_accessed_pointers.cpp
index 4c746e37a3c8f..e2cfd21af6f30 100644
--- a/taichi/analysis/gather_uniquely_accessed_pointers.cpp
+++ b/taichi/analysis/gather_uniquely_accessed_pointers.cpp
@@ -195,7 +195,7 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
         idx = idx->as<MeshIndexConversionStmt>()->idx;
       }
       if (idx->is<LoopIndexStmt>() &&
-        idx->as<LoopIndexStmt>()->is_mesh_index()) {  // from-end access
+          idx->as<LoopIndexStmt>()->is_mesh_index()) {  // from-end access
         if (rel_access_pointer_.find(snode) ==
             rel_access_pointer_.end()) {  // not accessed by neighbours yet
           accessed_pointer_[snode] = stmt;
diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp
index 371d4730a11f2..1a6949fa5f44d 100644
--- a/taichi/ir/control_flow_graph.cpp
+++ b/taichi/ir/control_flow_graph.cpp
@@ -450,8 +450,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
             continue;
           } else if (!is_parallel_executed ||
                      (atomic->dest->is<GlobalPtrStmt>() &&
-                      atomic->dest->as<GlobalPtrStmt>()
-                          ->snode->is_scalar())) {
+                      atomic->dest->as<GlobalPtrStmt>()->snode->is_scalar())) {
             // If this node is parallel executed, we can't weaken a global
             // atomic operation to a global load.
             // TODO: we can weaken it if it's element-wise (i.e. never
@@ -695,8 +694,7 @@ void ControlFlowGraph::live_variable_analysis(
     }
     if (auto *gptr = stmt->cast<GlobalPtrStmt>();
         gptr && config_opt.has_value()) {
-      const bool res =
-          (config_opt->eliminable_snodes.count(gptr->snode) == 0);
+      const bool res = (config_opt->eliminable_snodes.count(gptr->snode) == 0);
       return res;
     }
     // A global pointer that may be loaded after this kernel.
diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp
index 244a6816ac1f6..4909b40f4bf9e 100644
--- a/taichi/ir/statements.cpp
+++ b/taichi/ir/statements.cpp
@@ -52,7 +52,9 @@ ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr,
 GlobalPtrStmt::GlobalPtrStmt(SNode *snode,
                              const std::vector<Stmt *> &indices,
                              bool activate)
-    : snode(snode), indices(indices), activate(activate),
+    : snode(snode),
+      indices(indices),
+      activate(activate),
       is_bit_vectorized(false) {
   TI_ASSERT(snode != nullptr);
   element_type() = snode->dt;
diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h
index 0416a5168aba5..1bb042cf78846 100644
--- a/taichi/ir/statements.h
+++ b/taichi/ir/statements.h
@@ -306,8 +306,10 @@ class ExternalPtrStmt : public Stmt {
 
   ExternalPtrStmt(Stmt *base_ptr, const std::vector<Stmt *> &indices);
 
-  ExternalPtrStmt(Stmt *base_ptr, const std::vector<Stmt *> &indices,
-                  const std::vector<int> &element_shape, int element_dim);
+  ExternalPtrStmt(Stmt *base_ptr,
+                  const std::vector<Stmt *> &indices,
+                  const std::vector<int> &element_shape,
+                  int element_dim);
 
   bool has_global_side_effect() const override {
     return false;
@@ -332,7 +334,8 @@ class GlobalPtrStmt : public Stmt {
   bool activate;
   bool is_bit_vectorized;  // for bit_loop_vectorize pass
 
-  GlobalPtrStmt(SNode *snode, const std::vector<Stmt *> &indices,
+  GlobalPtrStmt(SNode *snode,
+                const std::vector<Stmt *> &indices,
                 bool activate = true);
 
   bool has_global_side_effect() const override {
diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp
index 2a91cc48dc41f..d58cd7aed24b5 100644
--- a/taichi/transforms/auto_diff.cpp
+++ b/taichi/transforms/auto_diff.cpp
@@ -1567,8 +1567,8 @@ class GloablDataAccessRuleChecker : public BasicStmtVisitor {
     }
     TI_ASSERT(snode->get_adjoint_checkbit() != nullptr);
     snode = snode->get_adjoint_checkbit();
-    auto global_ptr = stmt->insert_before_me(
-        Stmt::make<GlobalPtrStmt>(snode, dest->indices));
+    auto global_ptr =
+        stmt->insert_before_me(Stmt::make<GlobalPtrStmt>(snode, dest->indices));
     auto global_load =
         stmt->insert_before_me(Stmt::make<GlobalLoadStmt>(global_ptr));
     auto zero = stmt->insert_before_me(
diff --git a/taichi/transforms/demote_atomics.cpp b/taichi/transforms/demote_atomics.cpp
index 7b6cf210690b8..14b143e7b27d3 100644
--- a/taichi/transforms/demote_atomics.cpp
+++ b/taichi/transforms/demote_atomics.cpp
@@ -70,8 +70,8 @@ class DemoteAtomics : public BasicStmtVisitor {
               }
               if (idx->is<LoopIndexStmt>() &&
                   idx->as<LoopIndexStmt>()->is_mesh_index() &&
-                  loop_unique_ptr_[stmt->dest->as<GlobalPtrStmt>()
-                                       ->snode] != nullptr) {
+                  loop_unique_ptr_[stmt->dest->as<GlobalPtrStmt>()->snode] !=
+                      nullptr) {
                 demote = true;
               }
             }
diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp
index 6296574135c63..1597fb68157b5 100644
--- a/taichi/transforms/lower_access.cpp
+++ b/taichi/transforms/lower_access.cpp
@@ -99,7 +99,9 @@ class LowerAccess : public IRVisitor {
       // For ti.is_active
       TI_ASSERT(!activate);
     }
-    PtrLowererImpl lowerer{ptr->snode, ptr->indices, snode_op, ptr->is_bit_vectorized, &lowered, packed};
+    PtrLowererImpl lowerer{ptr->snode, ptr->indices,
+                           snode_op,   ptr->is_bit_vectorized,
+                           &lowered,   packed};
     lowerer.set_pointer_needs_activation(activate);
     lowerer.set_lower_access(this);
     lowerer.run();
diff --git a/taichi/transforms/make_thread_local.cpp b/taichi/transforms/make_thread_local.cpp
index 6c157456f6b46..d6e1ca852138d 100644
--- a/taichi/transforms/make_thread_local.cpp
+++ b/taichi/transforms/make_thread_local.cpp
@@ -107,8 +107,7 @@ void make_thread_local_offload(OffloadedStmt *offload) {
           // loss[None] (0-D fields) for now.
           // No TLS on quant types.
           return (dest->snode->type == SNodeType::place) &&
-                 dest->indices.empty() &&
-                 dest->snode->dt->is<PrimitiveType>();
+                 dest->indices.empty() && dest->snode->dt->is<PrimitiveType>();
         });
     auto valid_global_tmps =
         find_global_reduction_destinations<GlobalTemporaryStmt>(
diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp
index d6a024471adb8..78b4868e94747 100644
--- a/taichi/transforms/type_check.cpp
+++ b/taichi/transforms/type_check.cpp
@@ -154,8 +154,7 @@ class TypeCheck : public IRVisitor {
         stmt->snode->parent->num_active_indices != stmt->indices.size()) {
       TI_ERROR("[{}] {} has {} indices. Indexed with {}.", stmt->name(),
                stmt->snode->parent->node_type_name,
-               stmt->snode->parent->num_active_indices,
-               stmt->indices.size());
+               stmt->snode->parent->num_active_indices, stmt->indices.size());
     }
     for (int i = 0; i < stmt->indices.size(); i++) {
       if (!is_integral(stmt->indices[i]->ret_type)) {

From 6bb64119897ab848f1ef63080af68289e98b99b9 Mon Sep 17 00:00:00 2001
From: Yi Xu <xy_xuyi@foxmail.com>
Date: Mon, 29 Aug 2022 16:24:20 +0800
Subject: [PATCH 4/4] Revert accidental change

---
 taichi/transforms/scalar_pointer_lowerer.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/taichi/transforms/scalar_pointer_lowerer.cpp b/taichi/transforms/scalar_pointer_lowerer.cpp
index b392bbbe565c3..495506cb39cd6 100644
--- a/taichi/transforms/scalar_pointer_lowerer.cpp
+++ b/taichi/transforms/scalar_pointer_lowerer.cpp
@@ -37,7 +37,7 @@ void ScalarPointerLowerer::run() {
   // for a given SNode. It characterizes the relationship between a parent
   // and a child SNode: "parent.start = child.start + child.num_bits".
   //
-  // For example, if there are two 1D snode a and b,
+  // For example, if there are two 1D snodes a and b,
   // where a = ti.root.dense(ti.i, 2) and b = a.dense(ti.i, 8),
   // we have a.start = b.start + 3 for the i-th dimension.
   // When accessing b[15], then bits [0, 3) of 15 are for accessing b,