diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 07600d02bfe32..f72456a664f3d 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -15,8 +15,9 @@ from taichi.lang.ast.symbol_resolver import ASTResolver from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError from taichi.lang.field import Field +from taichi.lang.impl import current_cfg from taichi.lang.matrix import (Matrix, MatrixType, Vector, _PyScopeMatrixImpl, - _TiScopeMatrixImpl) + _TiScopeMatrixImpl, make_matrix) from taichi.lang.snode import append from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type from taichi.types import (annotations, ndarray_type, primitive_types, @@ -114,6 +115,12 @@ def build_Assign(ctx, node): @staticmethod def build_assign_slice(ctx, node_target, values, is_static_assign): target = ASTTransformer.build_Subscript(ctx, node_target, get_ref=True) + if current_cfg().real_matrix: + if isinstance(node_target.value.ptr, + any_array.AnyArray) and isinstance( + values, (list, tuple)): + values = make_matrix(values) + if isinstance(node_target.value.ptr, Matrix): if isinstance(node_target.value.ptr._impl, _TiScopeMatrixImpl): target._assign(values) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 0bca4fe175704..e91a17d197137 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -16,7 +16,7 @@ from taichi.lang.kernel_arguments import SparseMatrixProxy from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType, Vector, _IntermediateMatrix, - _MatrixFieldElement, make_matrix) + _MatrixFieldElement) from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance, MeshRelationAccessProxy, MeshReorderedMatrixFieldProxy, diff --git a/taichi/analysis/offline_cache_util.cpp b/taichi/analysis/offline_cache_util.cpp index 3207e0d887834..2237b3be60c63 100644 --- a/taichi/analysis/offline_cache_util.cpp +++ b/taichi/analysis/offline_cache_util.cpp @@ -66,6 +66,7 @@ static std::vector get_offline_cache_key_of_compile_config( serializer(config->experimental_auto_mesh_local); serializer(config->auto_mesh_local_default_occupacy); serializer(config->real_matrix); + serializer(config->real_matrix_scalarize); serializer.finalize(); return serializer.data; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index c3d58b362e52e..73a0ca345bc4c 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -203,6 +203,8 @@ void export_lang(py::module &m) { &CompileConfig::ndarray_use_cached_allocator) .def_readwrite("use_mesh", &CompileConfig::use_mesh) .def_readwrite("real_matrix", &CompileConfig::real_matrix) + .def_readwrite("real_matrix_scalarize", + &CompileConfig::real_matrix_scalarize) .def_readwrite("cc_compile_cmd", &CompileConfig::cc_compile_cmd) .def_readwrite("cc_link_cmd", &CompileConfig::cc_link_cmd) .def_readwrite("quant_opt_store_fusion", diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index e140a3bf0e665..70ca1cd551e49 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -8,10 +8,14 @@ TLANG_NAMESPACE_BEGIN class Scalarize : public IRVisitor { public: + DelayedIRModifier modifier_; + Scalarize(IRNode *node) { allow_undefined_visitor = true; invoke_default_visitor = false; node->accept(this); + + modifier_.modify_ir(); } /* @@ -51,18 +55,75 @@ class Scalarize : public IRVisitor { int num_elements = val_tensor_type->get_num_elements(); for (int i = 0; i < num_elements; i++) { auto const_stmt = std::make_unique( - TypedConstant(stmt->val->ret_type.get_element_type(), i)); + TypedConstant(get_data_type(), i)); auto ptr_offset_stmt = std::make_unique(stmt->dest, const_stmt.get()); auto scalarized_stmt = std::make_unique(ptr_offset_stmt.get(), matrix_init_stmt->values[i]); - stmt->insert_before_me(std::move(const_stmt)); - stmt->insert_before_me(std::move(ptr_offset_stmt)); - stmt->insert_before_me(std::move(scalarized_stmt)); + modifier_.insert_before(stmt, std::move(const_stmt)); + modifier_.insert_before(stmt, std::move(ptr_offset_stmt)); + modifier_.insert_before(stmt, std::move(scalarized_stmt)); + } + modifier_.erase(stmt); + } + } + + /* + + Before: + TensorType<4 x i32> val = LoadStmt(TensorType<4 x i32>* src) + + After: + i32* addr0 = PtrOffsetStmt(TensorType<4 x i32>* src, 0) + i32* addr1 = PtrOffsetStmt(TensorType<4 x i32>* src, 1) + i32* addr2 = PtrOffsetStmt(TensorType<4 x i32>* src, 2) + i32* addr3 = PtrOffsetStmt(TensorType<4 x i32>* src, 3) + + i32 val0 = LoadStmt(addr0) + i32 val1 = LoadStmt(addr1) + i32 val2 = LoadStmt(addr2) + i32 val3 = LoadStmt(addr3) + + tmp = MatrixInitStmt(val0, val1, val2, val3) + + stmt->replace_all_usages_with(tmp) + */ + template + void scalarize_load_stmt(T *stmt) { + auto src_dtype = stmt->src->ret_type.ptr_removed(); + if (src_dtype->template is()) { + // Needs scalarize + auto src_tensor_type = src_dtype->template as(); + + std::vector matrix_init_values; + int num_elements = src_tensor_type->get_num_elements(); + + for (size_t i = 0; i < num_elements; i++) { + auto const_stmt = std::make_unique( + TypedConstant(get_data_type(), i)); + + auto ptr_offset_stmt = + std::make_unique(stmt->src, const_stmt.get()); + auto scalarized_stmt = std::make_unique(ptr_offset_stmt.get()); + + matrix_init_values.push_back(scalarized_stmt.get()); + + modifier_.insert_before(stmt, std::move(const_stmt)); + modifier_.insert_before(stmt, std::move(ptr_offset_stmt)); + modifier_.insert_before(stmt, std::move(scalarized_stmt)); } - stmt->parent->erase(stmt); + + auto matrix_init_stmt = + std::make_unique(matrix_init_values); + + matrix_init_stmt->ret_type = src_dtype; + + stmt->replace_usages_with(matrix_init_stmt.get()); + modifier_.insert_before(stmt, std::move(matrix_init_stmt)); + + modifier_.erase(stmt); } } @@ -107,6 +168,14 @@ class Scalarize : public IRVisitor { void visit(LocalStoreStmt *stmt) override { scalarize_store_stmt(stmt); } + + void visit(GlobalLoadStmt *stmt) override { + scalarize_load_stmt(stmt); + } + + void visit(LocalLoadStmt *stmt) override { + scalarize_load_stmt(stmt); + } }; namespace irpass { diff --git a/tests/cpp/transforms/scalarize_test.cpp b/tests/cpp/transforms/scalarize_test.cpp index 749fb055e3d98..ab373492f6c0d 100644 --- a/tests/cpp/transforms/scalarize_test.cpp +++ b/tests/cpp/transforms/scalarize_test.cpp @@ -9,7 +9,7 @@ namespace lang { // Basic tests within a basic block template -void test_scalarize() { +void test_store_scalarize() { TestProgram test_prog; test_prog.setup(); @@ -76,9 +76,69 @@ void test_scalarize() { EXPECT_EQ(block->statements[16]->is(), true); } +template +void test_load_scalarize() { + TestProgram test_prog; + test_prog.setup(); + + auto block = std::make_unique(); + + auto func = []() {}; + auto kernel = + std::make_unique(*test_prog.prog(), func, "fake_kernel"); + block->kernel = kernel.get(); + + auto &type_factory = TypeFactory::get_instance(); + + /* + TensorType<4 x i32>* %1 = ExternalPtrStmt() + TensorType<4 x i32> %2 = LoadStmt(%1) + */ + Type *tensor_type = type_factory.get_tensor_type( + {2, 2}, type_factory.get_primitive_type(PrimitiveTypeID::i32)); + auto argload_stmt = block->push_back(0 /*arg_id*/, tensor_type); + + std::vector indices = {}; + Stmt *src_stmt = block->push_back( + argload_stmt, indices); // fake ExternalPtrStmt + src_stmt->ret_type = type_factory.get_pointer_type(tensor_type); + + block->push_back(src_stmt); + + irpass::scalarize(block.get()); + + EXPECT_EQ(block->size(), 1 /*argload*/ + 1 /*external_ptr*/ + 4 /*const*/ + + 4 /*ptroffset*/ + 4 /*load*/ + + 1 /*matrix_init*/); + + // Check for scalarized statements + EXPECT_EQ(block->statements[2]->is(), true); + EXPECT_EQ(block->statements[3]->is(), true); + EXPECT_EQ(block->statements[4]->is(), true); + + EXPECT_EQ(block->statements[5]->is(), true); + EXPECT_EQ(block->statements[6]->is(), true); + EXPECT_EQ(block->statements[7]->is(), true); + + EXPECT_EQ(block->statements[8]->is(), true); + EXPECT_EQ(block->statements[9]->is(), true); + EXPECT_EQ(block->statements[10]->is(), true); + + EXPECT_EQ(block->statements[11]->is(), true); + EXPECT_EQ(block->statements[12]->is(), true); + EXPECT_EQ(block->statements[13]->is(), true); + + EXPECT_EQ(block->statements[14]->is(), true); +} + TEST(Scalarize, ScalarizeStore) { - test_scalarize(); - test_scalarize(); + test_store_scalarize(); + test_store_scalarize(); +} + +TEST(Scalarize, ScalarizeLoad) { + test_load_scalarize(); + test_load_scalarize(); } } // namespace lang diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index a79746f8f4045..715e59c9210ad 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -817,3 +817,41 @@ def foo() -> ti.types.matrix(2, 2, ti.f32): return a @ b.transpose() assert foo() == [[1.0, 2.0], [2.0, 4.0]] + + +@test_utils.test(arch=[ti.cuda, ti.cpu], + real_matrix=True, + real_matrix_scalarize=True) +def test_store_scalarize(): + @ti.kernel + def func(a: ti.types.ndarray()): + for i in range(5): + a[i] = [[i, i + 1], [i + 2, i + 3]] + + x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5) + func(x) + + assert (x[0] == [[0, 1], [2, 3]]) + assert (x[1] == [[1, 2], [3, 4]]) + assert (x[2] == [[2, 3], [4, 5]]) + assert (x[3] == [[3, 4], [5, 6]]) + assert (x[4] == [[4, 5], [6, 7]]) + + +@test_utils.test(arch=[ti.cuda, ti.cpu], + real_matrix=True, + real_matrix_scalarize=True) +def test_load_store_scalarize(): + @ti.kernel + def func(a: ti.types.ndarray()): + for i in range(3): + a[i] = [[i, i + 1], [i + 2, i + 3]] + + a[3] = a[1] + a[4] = a[2] + + x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5) + func(x) + + assert (x[3] == [[1, 2], [3, 4]]) + assert (x[4] == [[2, 3], [4, 5]])