From 1034efff531aa851dd64b43c7dd888bfd828269d Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Thu, 13 Oct 2022 14:20:31 +0800 Subject: [PATCH 1/6] [Lang] MatrixNdarray refactor part12: Add scalarization for AtomicOpStmt --- python/taichi/lang/impl.py | 1 + taichi/ir/frontend_ir.cpp | 29 ++++++++++-- taichi/transforms/scalarize.cpp | 84 +++++++++++++++++++++++++++++++++ tests/python/test_matrix.py | 30 ++++++++++++ 4 files changed, 141 insertions(+), 3 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index b4f6b392384d5..08a81bcbde77f 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -52,6 +52,7 @@ def expr_init_shared_array(shape, element_type): @taichi_scope def expr_init(rhs): + print(11111) if rhs is None: return Expr(get_runtime().prog.current_ast_builder().expr_alloca()) if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")): diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 53125dd2c71e7..66470c7d10ec5 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -751,7 +751,7 @@ void IdExpression::flatten(FlattenContext *ctx) { } } -void AtomicOpExpression::type_check(CompileConfig *) { +void AtomicOpExpression::type_check(CompileConfig *config) { TI_ASSERT_TYPE_CHECKED(dest); TI_ASSERT_TYPE_CHECKED(val); auto error = [&]() { @@ -760,11 +760,34 @@ void AtomicOpExpression::type_check(CompileConfig *) { atomic_op_type_name(op_type), dest->ret_type->to_string(), val->ret_type->to_string())); }; - if (!val->ret_type->is()) + + // Broadcast val to dest if neccessary + auto val_dtype = val->ret_type; + auto dest_dtype = dest->ret_type.ptr_removed(); + if (dest_dtype->is() and val_dtype->is()) { error(); + } + + if (val_dtype->is() and dest_dtype->is()) { + auto broadcasted_expr = to_broadcast_tensor(val, dest_dtype); + val = std::move(broadcasted_expr); + val.type_check(config); + } + + // Validate dtype + auto dtype = val->ret_type; + if (config->real_matrix) { + dtype = dtype.get_element_type(); + } + + if (!dtype->is()) { + error(); + } + if (is_quant(dest->ret_type)) { ret_type = dest->ret_type->get_compute_type(); - } else if (dest->ret_type->is()) { + } else if (dest->ret_type->is() || + dest->ret_type->is()) { ret_type = dest->ret_type; } else { error(); diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 7632011b79e83..620a6b6081293 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -258,6 +258,90 @@ class Scalarize : public BasicStmtVisitor { } } + /* + Before: + TensorType<4 x i32> val = AtomicStmt(TensorType<4 x i32>* dest, + TensorType<4 x i32> val) + + After: + i32* dest_ptr_0 = MatrixPtrStmt(dest, 0) + i32* dest_ptr_1 = MatrixPtrStmt(dest, 1) + i32* dest_ptr_2 = MatrixPtrStmt(dest, 2) + i32* dest_ptr_3 = MatrixPtrStmt(dest, 3) + + i32 dest_val0 = AtomicStmt(dest_ptr_0, + val->cast()->val[0]) + i32 dest_val1 = AtomicStmt(dest_ptr_1, + val->cast()->val[1]) + i32 dest_val2 = AtomicStmt(dest_ptr_2, + val->cast()->val[2]) + i32 dest_val3 = AtomicStmt(dest_ptr_3, + val->cast()->val[3]) + + tmp = MatrixInitStmt(dest_val0, dest_val1, + dest_val2, dest_val3) + + stmt->replace_all_usages_with(tmp) + */ + void visit(AtomicOpStmt *stmt) override { + auto dest_dtype = stmt->dest->ret_type.ptr_removed(); + auto val_dtype = stmt->val->ret_type; + + if (dest_dtype->is() && val_dtype->is()) { + return; + } + + // AtomicOpExpression::type_check() should have taken care of the + // broadcasting and neccessary conversions. So we simply add an assertion + // here to make sure that the operands are of the same shape and dtype + if (dest_dtype->is() && val_dtype->is()) { + // Scalarization for LoadStmt should have already replaced val operand + // to MatrixInitStmt + TI_ASSERT(stmt->val->is()); + + auto val_matrix_init_stmt = stmt->val->cast(); + std::vector val_values = val_matrix_init_stmt->values; + + size_t num_elements = val_values.size(); + auto primitive_type = stmt->ret_type.get_element_type(); + + // Scalarize dest & val + std::vector matrix_init_values; + for (size_t i = 0; i < num_elements; i++) { + // scalarize to dest_i + auto const_stmt = std::make_unique( + TypedConstant(get_data_type(), i)); + auto matrix_ptr_stmt = + std::make_unique(stmt->dest, const_stmt.get()); + matrix_ptr_stmt->ret_type = primitive_type; + matrix_ptr_stmt->ret_type.set_is_pointer(true); + + // scalarize to val_i + auto val_stmt = val_values[i]; + + // assemble to scalarized atomic_op + auto atomic_stmt = std::make_unique( + stmt->op_type, matrix_ptr_stmt.get(), val_stmt); + atomic_stmt->ret_type = primitive_type; + + matrix_init_values.push_back(atomic_stmt.get()); + + modifier_.insert_before(stmt, std::move(const_stmt)); + modifier_.insert_before(stmt, std::move(matrix_ptr_stmt)); + modifier_.insert_before(stmt, std::move(atomic_stmt)); + } + + auto matrix_init_stmt = + std::make_unique(matrix_init_values); + matrix_init_stmt->ret_type = stmt->ret_type; + + stmt->replace_usages_with(matrix_init_stmt.get()); + modifier_.insert_before(stmt, std::move(matrix_init_stmt)); + + modifier_.erase(stmt); + } + } + void visit(GlobalStoreStmt *stmt) override { scalarize_store_stmt(stmt); } diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 9e55c1e24797f..61504fc9a304f 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -1044,3 +1044,33 @@ def verify(x): field = ti.Matrix.field(2, 2, ti.f32, shape=5) ndarray = ti.Matrix.ndarray(2, 2, ti.f32, shape=5) _test_field_and_ndarray(field, ndarray, func, verify) + + +@test_utils.test(arch=[ti.cuda, ti.cpu], + real_matrix=True, + real_matrix_scalarize=True, + debug=True) +def test_atomic_op_scalarize(): + @ti.func + def func(x: ti.template()): + x[0] = [1., 2., 3.] + tmp = ti.Vector([3., 2., 1.]) + z = ti.atomic_add(x[0], tmp) + assert z[0] == 1. + assert z[1] == 2. + assert z[2] == 3. + + # Broadcasting + x[1] = [1., 1., 1.] + g = ti.atomic_add(x[1], 2) + assert g[0] == 1. + assert g[1] == 1. + assert g[2] == 1. + + def verify(x): + assert (x[0] == [4., 4., 4.]).all() + assert (x[1] == [3., 3., 3.]).all() + + field = ti.Vector.field(n=3, dtype=ti.f32, shape=10) + ndarray = ti.Vector.ndarray(n=3, dtype=ti.f32, shape=(10)) + _test_field_and_ndarray(field, ndarray, func, verify) From 269176d9b1dbbb4e94c46b0444c1a64a726f0142 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Thu, 13 Oct 2022 14:22:35 +0800 Subject: [PATCH 2/6] Remove debug info --- python/taichi/lang/impl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 08a81bcbde77f..b4f6b392384d5 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -52,7 +52,6 @@ def expr_init_shared_array(shape, element_type): @taichi_scope def expr_init(rhs): - print(11111) if rhs is None: return Expr(get_runtime().prog.current_ast_builder().expr_alloca()) if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")): From 7d25905b361eabb6f518ff612adefbff8396048c Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Thu, 13 Oct 2022 16:52:11 +0800 Subject: [PATCH 3/6] Bug fix --- taichi/ir/frontend_ir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 66470c7d10ec5..1dbf2cef62369 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -776,7 +776,7 @@ void AtomicOpExpression::type_check(CompileConfig *config) { // Validate dtype auto dtype = val->ret_type; - if (config->real_matrix) { + if (dtype->is()) { dtype = dtype.get_element_type(); } From 9d6a65f7330acc59df0aaaae7ec0c953cca20674 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Fri, 14 Oct 2022 18:22:59 +0800 Subject: [PATCH 4/6] Update taichi/ir/frontend_ir.cpp Co-authored-by: Yi Xu --- taichi/ir/frontend_ir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 1dbf2cef62369..cbd6e80c2c909 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -764,7 +764,7 @@ void AtomicOpExpression::type_check(CompileConfig *config) { // Broadcast val to dest if neccessary auto val_dtype = val->ret_type; auto dest_dtype = dest->ret_type.ptr_removed(); - if (dest_dtype->is() and val_dtype->is()) { + if (dest_dtype->is() && val_dtype->is()) { error(); } From a202052b01bb4167e6360026a249421db5d6b34c Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Fri, 14 Oct 2022 18:23:04 +0800 Subject: [PATCH 5/6] Update taichi/ir/frontend_ir.cpp Co-authored-by: Yi Xu --- taichi/ir/frontend_ir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index cbd6e80c2c909..e95d8cf0990a7 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -768,7 +768,7 @@ void AtomicOpExpression::type_check(CompileConfig *config) { error(); } - if (val_dtype->is() and dest_dtype->is()) { + if (val_dtype->is() && dest_dtype->is()) { auto broadcasted_expr = to_broadcast_tensor(val, dest_dtype); val = std::move(broadcasted_expr); val.type_check(config); From 6076c7483897af451d7598e945535b1e54efc61e Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 17 Oct 2022 09:49:18 +0800 Subject: [PATCH 6/6] Add comments --- taichi/transforms/scalarize.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 620a6b6081293..e0ee6cdf9da61 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -291,9 +291,13 @@ class Scalarize : public BasicStmtVisitor { return; } - // AtomicOpExpression::type_check() should have taken care of the - // broadcasting and neccessary conversions. So we simply add an assertion - // here to make sure that the operands are of the same shape and dtype + // AtomicOpExpression::type_check() have taken care of the broadcasting, + // but the type conversions are delayed until irpass::type_check(). + // So we only check for the shape here. + TI_ASSERT(dest_dtype->is() && val_dtype->is()); + TI_ASSERT(dest_dtype->cast()->get_shape() == + val_dtype->cast()->get_shape()); + if (dest_dtype->is() && val_dtype->is()) { // Scalarization for LoadStmt should have already replaced val operand // to MatrixInitStmt @@ -313,8 +317,6 @@ class Scalarize : public BasicStmtVisitor { TypedConstant(get_data_type(), i)); auto matrix_ptr_stmt = std::make_unique(stmt->dest, const_stmt.get()); - matrix_ptr_stmt->ret_type = primitive_type; - matrix_ptr_stmt->ret_type.set_is_pointer(true); // scalarize to val_i auto val_stmt = val_values[i];