Skip to content

Commit

Permalink
[Bug] [lang] Make dimension check for GlobalPtrStmt aware of whether …
Browse files Browse the repository at this point in the history
…it is a cell access (#6275)

Issue: fix #6274

### Brief Summary

`GlobalPtrStmt` actually has two semantics:
- When it is part of an activation-related `SNodeOpStmt`
(`ti.activate(), ti.deactivate(), ti.is_active()`) and the SNode is not
of type `dynamic`, the indices are actually targeting a SNode **cell**;
- Otherwise, the indices are targeting the SNode **container**.

The original dimension check in the `type_check()` pass only targets the
second case, which results in the false alarm in #6274.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Oct 11, 2022
1 parent 4e1a733 commit 7e2ddaf
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 17 deletions.
5 changes: 4 additions & 1 deletion taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,10 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) {
flatten_rvalue(indices[i], ctx);
indices_stmt.push_back(indices[i]->stmt);
}
auto ptr = ctx->push_back<GlobalPtrStmt>(snode, indices_stmt);
auto is_cell_access = SNodeOpStmt::activation_related(op_type) &&
snode->type != SNodeType::dynamic;
auto ptr =
ctx->push_back<GlobalPtrStmt>(snode, indices_stmt, true, is_cell_access);
ptr->tb = tb;
if (op_type == SNodeOpType::is_active) {
TI_ERROR_IF(snode->type != SNodeType::pointer &&
Expand Down
4 changes: 3 additions & 1 deletion taichi/ir/statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr,

GlobalPtrStmt::GlobalPtrStmt(SNode *snode,
const std::vector<Stmt *> &indices,
bool activate)
bool activate,
bool is_cell_access)
: snode(snode),
indices(indices),
activate(activate),
is_cell_access(is_cell_access),
is_bit_vectorized(false) {
TI_ASSERT(snode != nullptr);
element_type() = snode->dt;
Expand Down
4 changes: 3 additions & 1 deletion taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,13 @@ class GlobalPtrStmt : public Stmt {
SNode *snode;
std::vector<Stmt *> indices;
bool activate;
bool is_cell_access;
bool is_bit_vectorized; // for bit_loop_vectorize pass

GlobalPtrStmt(SNode *snode,
const std::vector<Stmt *> &indices,
bool activate = true);
bool activate = true,
bool is_cell_access = false);

bool has_global_side_effect() const override {
return activate;
Expand Down
13 changes: 6 additions & 7 deletions taichi/transforms/lower_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,22 +153,21 @@ class LowerAccess : public IRVisitor {

void visit(SNodeOpStmt *stmt) override {
if (stmt->ptr->is<GlobalPtrStmt>()) {
if (SNodeOpStmt::activation_related(stmt->op_type) &&
stmt->snode->type != SNodeType::dynamic) {
auto lowered =
lower_ptr(stmt->ptr->as<GlobalPtrStmt>(), false, stmt->op_type);
auto global_ptr = stmt->ptr->as<GlobalPtrStmt>();
if (global_ptr->is_cell_access) {
auto lowered = lower_ptr(global_ptr, false, stmt->op_type);
modifier.replace_with(stmt, std::move(lowered), true);
} else if (stmt->op_type == SNodeOpType::get_addr) {
auto lowered = lower_ptr(stmt->ptr->as<GlobalPtrStmt>(), false);
auto lowered = lower_ptr(global_ptr, false);
auto cast = lowered.push_back<UnaryOpStmt>(UnaryOpType::cast_bits,
lowered.back().get());
cast->cast_type = TypeFactory::get_instance().get_primitive_type(
PrimitiveTypeID::u64);
stmt->ptr = lowered.back().get();
modifier.replace_with(stmt, std::move(lowered));
} else {
auto lowered = lower_ptr(stmt->ptr->as<GlobalPtrStmt>(),
SNodeOpStmt::need_activation(stmt->op_type));
auto lowered =
lower_ptr(global_ptr, SNodeOpStmt::need_activation(stmt->op_type));
stmt->ptr = lowered.back().get();
modifier.insert_before(stmt, std::move(lowered));
}
Expand Down
3 changes: 2 additions & 1 deletion taichi/transforms/lower_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,8 @@ class LowerAST : public IRVisitor {
stmt->snode->type == SNodeType::dense ||
stmt->snode->type == SNodeType::bitmasked) {
TI_ASSERT(SNodeOpStmt::activation_related(stmt->op_type));
auto ptr = fctx.push_back<GlobalPtrStmt>(stmt->snode, indices_stmt);
auto ptr =
fctx.push_back<GlobalPtrStmt>(stmt->snode, indices_stmt, true, true);
fctx.push_back<SNodeOpStmt>(stmt->op_type, stmt->snode, ptr, val_stmt);
} else {
TI_ERROR("The {} operation is not supported on {} SNode",
Expand Down
14 changes: 8 additions & 6 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,14 @@ class TypeCheck : public IRVisitor {
} else
TI_WARN("[{}] Type inference failed: snode is nullptr.\n{}", stmt->name(),
stmt->tb);
if (stmt->snode->parent->num_active_indices != 0 &&
stmt->snode->parent->num_active_indices != stmt->indices.size()) {
TI_ERROR("[{}] {} has {} indices. Indexed with {}.", stmt->name(),
stmt->snode->parent->node_type_name,
stmt->snode->parent->num_active_indices, stmt->indices.size());
}
auto check_indices = [&](SNode *snode) {
if (snode->num_active_indices != stmt->indices.size()) {
TI_ERROR("[{}] {} has {} indices. Indexed with {}.", stmt->name(),
snode->node_type_name, snode->num_active_indices,
stmt->indices.size());
}
};
check_indices(stmt->is_cell_access ? stmt->snode : stmt->snode->parent);
for (int i = 0; i < stmt->indices.size(); i++) {
if (!stmt->indices[i]->ret_type->is_primitive(PrimitiveTypeID::i32)) {
TI_WARN(
Expand Down
23 changes: 23 additions & 0 deletions tests/python/test_sparse_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,29 @@ def func():
assert s[None] == 256


@test_utils.test(require=ti.extension.sparse)
def test_pointer_is_active_2():
x = ti.field(ti.f32)
s = ti.field(ti.i32)

n = 128

ti.root.dense(ti.i, n).pointer(ti.j, n).place(x)
ti.root.place(s)

@ti.kernel
def func():
for i, j in ti.ndrange(n, n):
s[None] += ti.is_active(x.parent(), [i, j])

x[0, 0] = 1
x[0, 127] = 1
x[127, 127] = 1

func()
assert s[None] == 3


def _test_pointer2():
x = ti.field(ti.f32)
s = ti.field(ti.i32)
Expand Down

0 comments on commit 7e2ddaf

Please sign in to comment.