From 9af3f1f93df5f33d770b4f1282acb6abe8b7f6dd Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Wed, 28 Sep 2022 17:27:25 +0800 Subject: [PATCH 1/2] [Lang] Allow dynamic index of matrix field when real_matrix=True --- taichi/transforms/lower_matrix_ptr.cpp | 27 +++++++++++++++++++------- tests/python/test_matrix.py | 19 ++++++++++++++---- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/taichi/transforms/lower_matrix_ptr.cpp b/taichi/transforms/lower_matrix_ptr.cpp index 69f3721e72181..fee343b4ae9cd 100644 --- a/taichi/transforms/lower_matrix_ptr.cpp +++ b/taichi/transforms/lower_matrix_ptr.cpp @@ -13,14 +13,27 @@ class LowerMatrixPtr : public BasicStmtVisitor { void visit(MatrixPtrStmt *stmt) override { if (stmt->origin->is()) { - TI_ASSERT(stmt->offset->is()); auto origin = stmt->origin->as(); - auto offset = stmt->offset->as(); - auto lowered = std::make_unique( - origin->snodes[offset->val.val_int()], origin->indices); - stmt->replace_usages_with(lowered.get()); - modifier.insert_before(stmt, std::move(lowered)); - modifier.erase(stmt); + if (stmt->offset->is()) { + auto offset = stmt->offset->as(); + auto lowered = std::make_unique( + origin->snodes[offset->val.val_int()], origin->indices); + stmt->replace_usages_with(lowered.get()); + modifier.insert_before(stmt, std::move(lowered)); + modifier.erase(stmt); + } else { + TI_ASSERT_INFO(origin->dynamic_indexable, "Element of the MatrixField is not dynamic indexable.\n{}", stmt->tb); + auto stride = std::make_unique(TypedConstant(origin->dynamic_index_stride)); + auto offset = std::make_unique(BinaryOpType::mul, stmt->offset, stride.get()); + auto ptr_base = std::make_unique(origin->snodes[0], origin->indices); + auto lowered = std::make_unique(ptr_base.get(), offset.get()); + stmt->replace_usages_with(lowered.get()); + modifier.insert_before(stmt, std::move(stride)); + modifier.insert_before(stmt, std::move(offset)); + modifier.insert_before(stmt, std::move(ptr_base)); + modifier.insert_before(stmt, std::move(lowered)); + modifier.erase(stmt); + } } } diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 74ddb3f706617..3cf75032ba5cc 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -193,10 +193,7 @@ def func2(b: ti.types.ndarray(element_dim=1)): assert v[3][9] == 9 -@test_utils.test(require=ti.extension.dynamic_index, - dynamic_index=True, - debug=True) -def test_matrix_non_constant_index(): +def _test_matrix_non_constant_index(): m = ti.Matrix.field(2, 2, ti.i32, 5) v = ti.Vector.field(10, ti.i32, 5) @@ -248,6 +245,20 @@ def func4(k: ti.i32): func4(10) +@test_utils.test(require=ti.extension.dynamic_index, + dynamic_index=True, + debug=True) +def test_matrix_non_constant_index(): + _test_matrix_non_constant_index() + + +@test_utils.test(require=ti.extension.dynamic_index, + real_matrix=True, + debug=True) +def test_matrix_non_constant_index_real_matrix(): + _test_matrix_non_constant_index() + + def _test_matrix_constant_index(): m = ti.Matrix.field(2, 2, ti.i32, 5) From 415c89a117053d6b40f429cc4093ddb571e93d7d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Sep 2022 09:32:13 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/transforms/lower_matrix_ptr.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/taichi/transforms/lower_matrix_ptr.cpp b/taichi/transforms/lower_matrix_ptr.cpp index fee343b4ae9cd..de385457a572e 100644 --- a/taichi/transforms/lower_matrix_ptr.cpp +++ b/taichi/transforms/lower_matrix_ptr.cpp @@ -22,11 +22,18 @@ class LowerMatrixPtr : public BasicStmtVisitor { modifier.insert_before(stmt, std::move(lowered)); modifier.erase(stmt); } else { - TI_ASSERT_INFO(origin->dynamic_indexable, "Element of the MatrixField is not dynamic indexable.\n{}", stmt->tb); - auto stride = std::make_unique(TypedConstant(origin->dynamic_index_stride)); - auto offset = std::make_unique(BinaryOpType::mul, stmt->offset, stride.get()); - auto ptr_base = std::make_unique(origin->snodes[0], origin->indices); - auto lowered = std::make_unique(ptr_base.get(), offset.get()); + TI_ASSERT_INFO( + origin->dynamic_indexable, + "Element of the MatrixField is not dynamic indexable.\n{}", + stmt->tb); + auto stride = std::make_unique( + TypedConstant(origin->dynamic_index_stride)); + auto offset = std::make_unique( + BinaryOpType::mul, stmt->offset, stride.get()); + auto ptr_base = + std::make_unique(origin->snodes[0], origin->indices); + auto lowered = + std::make_unique(ptr_base.get(), offset.get()); stmt->replace_usages_with(lowered.get()); modifier.insert_before(stmt, std::move(stride)); modifier.insert_before(stmt, std::move(offset));