diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index be49ab30c2b93..685bfe42d7178 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -236,7 +236,6 @@ def subscript(value, *_indices, skip_reordered=False, get_ref=False): assert current_cfg().real_matrix is True assert is_tensor(value.ptr.get_ret_type()) - # TODO(zhanlue): Merge _ti_core.subscript and _ti_core.make_index_expr return Expr( _ti_core.subscript(value.ptr, indices_expr_group, get_runtime().get_current_src_info())) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index b506d4354accc..47171cfb69697 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -12,6 +12,10 @@ TLANG_NAMESPACE_BEGIN "[{}] was not type-checked", \ ExpressionHumanFriendlyPrinter::expr_to_string(x)) +static bool is_primitive_or_tensor_type(DataType &type) { + return type->is() || type->is(); +} + FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type, SNode *snode, const ExprGroup &indices, @@ -180,8 +184,17 @@ void BinaryOpExpression::type_check(CompileConfig *config) { binary_op_type_symbol(type), lhs->ret_type->to_string(), rhs->ret_type->to_string())); }; - if (!lhs_type->is() || !rhs_type->is()) + + if (!is_primitive_or_tensor_type(lhs_type) || + !is_primitive_or_tensor_type(rhs_type)) { error(); + } + + if ((lhs_type->is() && rhs_type->is()) || + (lhs_type->is() && rhs_type->is())) { + TI_NOT_IMPLEMENTED; + } + if (binary_is_bitwise(type) && (!is_integral(lhs_type) || !is_integral(rhs_type))) error(); diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 8fe19312d16ff..c1d1abd1d1878 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -65,8 +65,10 @@ PtrOffsetStmt::PtrOffsetStmt(Stmt *origin_input, Stmt *offset_input) { origin = origin_input; offset = offset_input; if (origin->is()) { - TI_ASSERT(origin->cast()->ret_type->is()); - auto tensor_type = origin->cast()->ret_type->cast(); + TI_ASSERT( + origin->cast()->ret_type.ptr_removed()->is()); + auto tensor_type = + origin->cast()->ret_type.ptr_removed()->cast(); element_type() = tensor_type->get_element_type(); element_type().set_is_pointer(true); } else if (origin->is()) { @@ -78,9 +80,12 @@ PtrOffsetStmt::PtrOffsetStmt(Stmt *origin_input, Stmt *offset_input) { } else if (origin->is()) { element_type() = origin->cast()->ret_type; } else if (origin->is()) { - TI_ASSERT(origin->cast()->ret_type->is()); - auto tensor_type = - origin->cast()->ret_type->cast(); + TI_ASSERT(origin->cast() + ->ret_type.ptr_removed() + ->is()); + auto tensor_type = origin->cast() + ->ret_type.ptr_removed() + ->cast(); element_type() = tensor_type->get_element_type(); element_type().set_is_pointer(true); } else { diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index ac6006646bc3a..6d10d15974234 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -29,6 +29,7 @@ namespace irpass { void re_id(IRNode *root); void flag_access(IRNode *root); +void scalarize(IRNode *root); bool die(IRNode *root); bool simplify(IRNode *root, const CompileConfig &config); bool cfg_optimization( diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index 78cb1480495bc..a3e022e47bd93 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -49,6 +49,7 @@ CompileConfig::CompileConfig() { ndarray_use_cached_allocator = true; use_mesh = false; real_matrix = false; + real_matrix_scalarize = false; saturating_grid_dim = 0; max_block_dim = 0; diff --git a/taichi/program/compile_config.h b/taichi/program/compile_config.h index 14f17e419fd6d..a908662129cbd 100644 --- a/taichi/program/compile_config.h +++ b/taichi/program/compile_config.h @@ -45,6 +45,7 @@ struct CompileConfig { bool ndarray_use_cached_allocator; bool use_mesh; bool real_matrix; + bool real_matrix_scalarize; DataType default_fp; DataType default_ip; DataType default_up; diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 505ea107ff1f5..d5d5e8ab385aa 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -52,6 +52,11 @@ void compile_to_offloads(IRNode *ir, print("Lowered"); } + if (config.real_matrix && config.real_matrix_scalarize) { + irpass::scalarize(ir); + print("Scalarized"); + } + irpass::type_check(ir, config); print("Typechecked"); irpass::analysis::verify(ir); @@ -316,6 +321,12 @@ void compile_function(IRNode *ir, irpass::lower_ast(ir); print("Lowered"); } + + if (config.real_matrix && config.real_matrix_scalarize) { + irpass::scalarize(ir); + print("Scalarized"); + } + irpass::lower_access(ir, config, {{}, true}); print("Access lowered"); irpass::analysis::verify(ir); diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp new file mode 100644 index 0000000000000..e140a3bf0e665 --- /dev/null +++ b/taichi/transforms/scalarize.cpp @@ -0,0 +1,126 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" +#include "taichi/system/profiler.h" + +TLANG_NAMESPACE_BEGIN + +class Scalarize : public IRVisitor { + public: + Scalarize(IRNode *node) { + allow_undefined_visitor = true; + invoke_default_visitor = false; + node->accept(this); + } + + /* + "val" of StoreStmt should have already been replaced by a MatrixInitStmt in + former scalarization. + + Before: + StoreStmt(TensorType<4 x i32>* dest, TensorType<4 x i32> val) + + After: + addr0 = PtrOffsetStmt(TensorType<4 x i32>* dest, 0) + addr1 = PtrOffsetStmt(TensorType<4 x i32>* dest, 1) + addr2 = PtrOffsetStmt(TensorType<4 x i32>* dest, 2) + addr2 = PtrOffsetStmt(TensorType<4 x i32>* dest, 3) + + StoreStmt(i32* addr0, i32 val->cast()->val[0]) + StoreStmt(i32* addr1, i32 val->cast()->val[1]) + StoreStmt(i32* addr2, i32 val->cast()->val[2]) + StoreStmt(i32* addr3, i32 val->cast()->val[3]) + */ + template + void scalarize_store_stmt(T *stmt) { + auto dest_dtype = stmt->dest->ret_type.ptr_removed(); + auto val_dtype = stmt->val->ret_type; + if (dest_dtype->template is() && + val_dtype->template is()) { + // Needs scalarize + auto dest_tensor_type = dest_dtype->template as(); + auto val_tensor_type = val_dtype->template as(); + TI_ASSERT(dest_tensor_type->get_shape() == val_tensor_type->get_shape()); + TI_ASSERT(dest_tensor_type->get_element_type() == + val_tensor_type->get_element_type()); + + TI_ASSERT(stmt->val->template is()); + auto matrix_init_stmt = stmt->val->template as(); + + int num_elements = val_tensor_type->get_num_elements(); + for (int i = 0; i < num_elements; i++) { + auto const_stmt = std::make_unique( + TypedConstant(stmt->val->ret_type.get_element_type(), i)); + + auto ptr_offset_stmt = + std::make_unique(stmt->dest, const_stmt.get()); + auto scalarized_stmt = std::make_unique(ptr_offset_stmt.get(), + matrix_init_stmt->values[i]); + + stmt->insert_before_me(std::move(const_stmt)); + stmt->insert_before_me(std::move(ptr_offset_stmt)); + stmt->insert_before_me(std::move(scalarized_stmt)); + } + stmt->parent->erase(stmt); + } + } + + void visit(Block *stmt_list) override { + for (auto &stmt : stmt_list->statements) { + stmt->accept(this); + } + } + + void visit(IfStmt *if_stmt) override { + if (if_stmt->true_statements) + if_stmt->true_statements->accept(this); + if (if_stmt->false_statements) { + if_stmt->false_statements->accept(this); + } + } + + void visit(WhileStmt *stmt) override { + stmt->body->accept(this); + } + + void visit(RangeForStmt *for_stmt) override { + for_stmt->body->accept(this); + } + + void visit(StructForStmt *for_stmt) override { + for_stmt->body->accept(this); + } + + void visit(MeshForStmt *for_stmt) override { + for_stmt->body->accept(this); + } + + void visit(OffloadedStmt *stmt) override { + stmt->all_blocks_accept(this); + } + + void visit(GlobalStoreStmt *stmt) override { + scalarize_store_stmt(stmt); + } + + void visit(LocalStoreStmt *stmt) override { + scalarize_store_stmt(stmt); + } +}; + +namespace irpass { + +void scalarize(IRNode *root) { + TI_AUTO_PROF; + Scalarize scalarize_pass(root); + + /* TODO(zhanlue): Remove redundant MatrixInitStmt + Scalarize pass will generate temporary MatrixInitStmts, which are only used + as rvalues. Remove these MatrixInitStmts since it's no longer needed. + */ +} + +} // namespace irpass + +TLANG_NAMESPACE_END diff --git a/tests/cpp/transforms/scalarize_test.cpp b/tests/cpp/transforms/scalarize_test.cpp new file mode 100644 index 0000000000000..749fb055e3d98 --- /dev/null +++ b/tests/cpp/transforms/scalarize_test.cpp @@ -0,0 +1,85 @@ +#include "gtest/gtest.h" + +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" +#include "tests/cpp/program/test_program.h" + +namespace taichi { +namespace lang { + +// Basic tests within a basic block +template +void test_scalarize() { + TestProgram test_prog; + test_prog.setup(); + + auto block = std::make_unique(); + + auto func = []() {}; + auto kernel = + std::make_unique(*test_prog.prog(), func, "fake_kernel"); + block->kernel = kernel.get(); + + auto &type_factory = TypeFactory::get_instance(); + + /* + TensorType<4 x i32>* %1 = ExternalPtrStmt() + TensorType<4 x i32> %2 = MatrixInitStmt([1, 1, 2, 2]) + StoreStmt(%1, %2) + */ + Type *tensor_type = type_factory.get_tensor_type( + {2, 2}, type_factory.get_primitive_type(PrimitiveTypeID::i32)); + auto const_1_stmt = block->push_back(TypedConstant(1)); + auto const_2_stmt = block->push_back(TypedConstant(2)); + auto argload_stmt = block->push_back(0 /*arg_id*/, tensor_type); + + Stmt *dest_stmt = nullptr; + if (std::is_same::value) { + std::vector indices = {}; + dest_stmt = block->push_back( + argload_stmt, indices); // fake ExternalPtrStmt + + } else { + dest_stmt = block->push_back(tensor_type); + } + dest_stmt->ret_type = type_factory.get_pointer_type(tensor_type); + + std::vector matrix_init_vals = {const_1_stmt, const_1_stmt, + const_2_stmt, const_2_stmt}; + auto matrix_init_stmt = + block->push_back(std::move(matrix_init_vals)); + matrix_init_stmt->ret_type = tensor_type; + + block->push_back(dest_stmt, matrix_init_stmt); + + irpass::scalarize(block.get()); + + EXPECT_EQ(block->size(), 2 /*const*/ + 1 /*argload*/ + 1 /*external_ptr*/ + + 1 /*matrix_init*/ + 4 /*const*/ + + 4 /*ptroffset*/ + 4 /*store*/); + + // Check for scalarized statements + EXPECT_EQ(block->statements[5]->is(), true); + EXPECT_EQ(block->statements[6]->is(), true); + EXPECT_EQ(block->statements[7]->is(), true); + + EXPECT_EQ(block->statements[8]->is(), true); + EXPECT_EQ(block->statements[9]->is(), true); + EXPECT_EQ(block->statements[10]->is(), true); + + EXPECT_EQ(block->statements[11]->is(), true); + EXPECT_EQ(block->statements[12]->is(), true); + EXPECT_EQ(block->statements[13]->is(), true); + + EXPECT_EQ(block->statements[14]->is(), true); + EXPECT_EQ(block->statements[15]->is(), true); + EXPECT_EQ(block->statements[16]->is(), true); +} + +TEST(Scalarize, ScalarizeStore) { + test_scalarize(); + test_scalarize(); +} + +} // namespace lang +} // namespace taichi