diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 88d94757665ce..3e141bd4f4986 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -810,7 +810,10 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) { flatten_rvalue(indices[i], ctx); indices_stmt.push_back(indices[i]->stmt); } - auto ptr = ctx->push_back(snode, indices_stmt); + auto is_cell_access = SNodeOpStmt::activation_related(op_type) && + snode->type != SNodeType::dynamic; + auto ptr = + ctx->push_back(snode, indices_stmt, true, is_cell_access); ptr->tb = tb; if (op_type == SNodeOpType::is_active) { TI_ERROR_IF(snode->type != SNodeType::pointer && diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 7a262d4232d7f..1cc717934ae31 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -51,10 +51,12 @@ ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr, GlobalPtrStmt::GlobalPtrStmt(SNode *snode, const std::vector &indices, - bool activate) + bool activate, + bool is_cell_access) : snode(snode), indices(indices), activate(activate), + is_cell_access(is_cell_access), is_bit_vectorized(false) { TI_ASSERT(snode != nullptr); element_type() = snode->dt; diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 4a9fd13112d47..887dc7fbacd1b 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -345,11 +345,13 @@ class GlobalPtrStmt : public Stmt { SNode *snode; std::vector indices; bool activate; + bool is_cell_access; bool is_bit_vectorized; // for bit_loop_vectorize pass GlobalPtrStmt(SNode *snode, const std::vector &indices, - bool activate = true); + bool activate = true, + bool is_cell_access = false); bool has_global_side_effect() const override { return activate; diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index 8ce05f73484a5..4b3db32fa555d 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -153,13 +153,12 @@ class LowerAccess : public IRVisitor { void visit(SNodeOpStmt *stmt) override { if (stmt->ptr->is()) { - if (SNodeOpStmt::activation_related(stmt->op_type) && - stmt->snode->type != SNodeType::dynamic) { - auto lowered = - lower_ptr(stmt->ptr->as(), false, stmt->op_type); + auto global_ptr = stmt->ptr->as(); + if (global_ptr->is_cell_access) { + auto lowered = lower_ptr(global_ptr, false, stmt->op_type); modifier.replace_with(stmt, std::move(lowered), true); } else if (stmt->op_type == SNodeOpType::get_addr) { - auto lowered = lower_ptr(stmt->ptr->as(), false); + auto lowered = lower_ptr(global_ptr, false); auto cast = lowered.push_back(UnaryOpType::cast_bits, lowered.back().get()); cast->cast_type = TypeFactory::get_instance().get_primitive_type( @@ -167,8 +166,8 @@ class LowerAccess : public IRVisitor { stmt->ptr = lowered.back().get(); modifier.replace_with(stmt, std::move(lowered)); } else { - auto lowered = lower_ptr(stmt->ptr->as(), - SNodeOpStmt::need_activation(stmt->op_type)); + auto lowered = + lower_ptr(global_ptr, SNodeOpStmt::need_activation(stmt->op_type)); stmt->ptr = lowered.back().get(); modifier.insert_before(stmt, std::move(lowered)); } diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index d3703e102dfee..358fed688eb93 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -433,7 +433,8 @@ class LowerAST : public IRVisitor { stmt->snode->type == SNodeType::dense || stmt->snode->type == SNodeType::bitmasked) { TI_ASSERT(SNodeOpStmt::activation_related(stmt->op_type)); - auto ptr = fctx.push_back(stmt->snode, indices_stmt); + auto ptr = + fctx.push_back(stmt->snode, indices_stmt, true, true); fctx.push_back(stmt->op_type, stmt->snode, ptr, val_stmt); } else { TI_ERROR("The {} operation is not supported on {} SNode", diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 79d2daeee4ff7..cbe3a34b2324f 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -145,12 +145,14 @@ class TypeCheck : public IRVisitor { } else TI_WARN("[{}] Type inference failed: snode is nullptr.\n{}", stmt->name(), stmt->tb); - if (stmt->snode->parent->num_active_indices != 0 && - stmt->snode->parent->num_active_indices != stmt->indices.size()) { - TI_ERROR("[{}] {} has {} indices. Indexed with {}.", stmt->name(), - stmt->snode->parent->node_type_name, - stmt->snode->parent->num_active_indices, stmt->indices.size()); - } + auto check_indices = [&](SNode *snode) { + if (snode->num_active_indices != stmt->indices.size()) { + TI_ERROR("[{}] {} has {} indices. Indexed with {}.", stmt->name(), + snode->node_type_name, snode->num_active_indices, + stmt->indices.size()); + } + }; + check_indices(stmt->is_cell_access ? stmt->snode : stmt->snode->parent); for (int i = 0; i < stmt->indices.size(); i++) { if (!stmt->indices[i]->ret_type->is_primitive(PrimitiveTypeID::i32)) { TI_WARN( diff --git a/tests/python/test_sparse_basics.py b/tests/python/test_sparse_basics.py index 118f3bfbbfd12..c7e6e5e20782e 100644 --- a/tests/python/test_sparse_basics.py +++ b/tests/python/test_sparse_basics.py @@ -50,6 +50,29 @@ def func(): assert s[None] == 256 +@test_utils.test(require=ti.extension.sparse) +def test_pointer_is_active_2(): + x = ti.field(ti.f32) + s = ti.field(ti.i32) + + n = 128 + + ti.root.dense(ti.i, n).pointer(ti.j, n).place(x) + ti.root.place(s) + + @ti.kernel + def func(): + for i, j in ti.ndrange(n, n): + s[None] += ti.is_active(x.parent(), [i, j]) + + x[0, 0] = 1 + x[0, 127] = 1 + x[127, 127] = 1 + + func() + assert s[None] == 3 + + def _test_pointer2(): x = ti.field(ti.f32) s = ti.field(ti.i32)