From 2f9643e73c2ed71502fe68219fd2b0c64cf93623 Mon Sep 17 00:00:00 2001
From: jim19930609
Date: Wed, 21 Sep 2022 11:46:40 +0800
Subject: [PATCH 1/4] [Lang] MatrixNdarray refactor part8: Add scalarization
 for BinaryOpStmt with TensorType-operands

---
 python/taichi/lang/matrix.py    |  3 ++
 taichi/ir/frontend_ir.cpp       | 48 ++++++++++++------
 taichi/transforms/scalarize.cpp | 88 ++++++++++++++++++++++++++++++++-
 tests/python/test_matrix.py     | 20 ++++++++
 4 files changed, 144 insertions(+), 15 deletions(-)

diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py
index c1e148ef00584..cde81a277c532 100644
--- a/python/taichi/lang/matrix.py
+++ b/python/taichi/lang/matrix.py
@@ -502,6 +502,9 @@ def __init__(self,
 
     def _element_wise_binary(self, foo, other):
         other = self._broadcast_copy(other)
+        if impl.current_cfg().real_matrix:
+            return foo(self, other)
+
         if is_col_vector(self):
             return Vector([foo(self(i), other(i)) for i in range(self.n)],
                           ndim=self.ndim)
diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp
index 893908bd3f3a9..6b1e665cfa0e7 100644
--- a/taichi/ir/frontend_ir.cpp
+++ b/taichi/ir/frontend_ir.cpp
@@ -203,6 +203,7 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
   TI_ASSERT_TYPE_CHECKED(rhs);
   auto lhs_type = lhs->ret_type;
   auto rhs_type = rhs->ret_type;
+
   auto error = [&]() {
     throw TaichiTypeError(
         fmt::format("unsupported operand type(s) for '{}': '{}' and '{}'",
@@ -220,43 +221,62 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
     TI_NOT_IMPLEMENTED;
   }
 
+  /*
+    Dtype inference for TensorType and PrimitiveType operands is essentially
+    the same. Therefore we extract the primitive type to perform the type
+    inference, and then reconstruct the TensorType where necessary.
+  */
+  auto lhs_primitive_type = lhs->ret_type.get_element_type();
+  auto rhs_primitive_type = rhs->ret_type.get_element_type();
+  auto ret_primitive_type = ret_type;
+
   if (binary_is_bitwise(type) &&
-      (!is_integral(lhs_type) || !is_integral(rhs_type)))
+      (!is_integral(lhs_primitive_type) || !is_integral(rhs_primitive_type)))
     error();
-  if (binary_is_logical(type) &&
-      (lhs_type != PrimitiveType::i32 || rhs_type != PrimitiveType::i32))
+  if (binary_is_logical(type) && (lhs_primitive_type != PrimitiveType::i32 ||
+                                  rhs_primitive_type != PrimitiveType::i32))
     error();
   if (is_comparison(type) || binary_is_logical(type)) {
-    ret_type = PrimitiveType::i32;
+    ret_primitive_type = PrimitiveType::i32;
     return;
   }
   if (is_shift_op(type) ||
-      (type == BinaryOpType::pow && is_integral(rhs_type))) {
-    ret_type = lhs_type;
+      (type == BinaryOpType::pow && is_integral(rhs_primitive_type))) {
+    ret_primitive_type = lhs_primitive_type;
     return;
   }
 
   // Some backends such as vulkan doesn't support fp64
   // Try not promoting to fp64 unless necessary
   if (type == BinaryOpType::atan2) {
-    if (lhs_type == PrimitiveType::f64 || rhs_type == PrimitiveType::f64) {
-      ret_type = PrimitiveType::f64;
+    if (lhs_primitive_type == PrimitiveType::f64 ||
+        rhs_primitive_type == PrimitiveType::f64) {
+      ret_primitive_type = PrimitiveType::f64;
     } else {
-      ret_type = PrimitiveType::f32;
+      ret_primitive_type = PrimitiveType::f32;
     }
     return;
   }
   if (type == BinaryOpType::truediv) {
     auto default_fp = config->default_fp;
-    if (!is_real(lhs_type)) {
-      lhs_type = default_fp;
+    if (!is_real(lhs_primitive_type)) {
+      lhs_primitive_type = default_fp;
     }
-    if (!is_real(rhs_type)) {
-      rhs_type = default_fp;
+    if (!is_real(rhs_primitive_type)) {
+      rhs_primitive_type = default_fp;
     }
   }
-  ret_type = promoted_type(lhs_type, rhs_type);
+  ret_primitive_type = promoted_type(lhs_primitive_type, rhs_primitive_type);
+
+  if (rhs_type->is<TensorType>() && lhs_type->is<TensorType>()) {
+    ret_type = taichi::lang::TypeFactory::get_instance().get_tensor_type(
+        rhs_type.get_shape(), ret_primitive_type);
+  } else if (rhs_type->is<PrimitiveType>() && lhs_type->is<PrimitiveType>()) {
+    ret_type = ret_primitive_type;
+  } else {
+    TI_NOT_IMPLEMENTED;
+  }
 }
 
 void BinaryOpExpression::flatten(FlattenContext *ctx) {
diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp
index 8f6afbac0d630..b7895349287bd 100644
--- a/taichi/transforms/scalarize.cpp
+++ b/taichi/transforms/scalarize.cpp
@@ -116,7 +116,6 @@ class Scalarize : public IRVisitor {
 
     auto matrix_init_stmt =
         std::make_unique<MatrixInitStmt>(matrix_init_values);
-
     matrix_init_stmt->ret_type = src_dtype;
 
     stmt->replace_usages_with(matrix_init_stmt.get());
@@ -178,6 +177,93 @@ class Scalarize : public IRVisitor {
     }
   }
 
+  /*
+    Before:
+      TensorType<4 x i32> val = BinaryStmt(TensorType<4 x i32> lhs,
+                                           TensorType<4 x i32> rhs)
+
+      * Note that "lhs" and "rhs" should have already been scalarized to
+        MatrixInitStmt
+
+    After:
+      i32 calc_val0 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[0],
+                                 rhs->cast<MatrixInitStmt>()->val[0])
+      i32 calc_val1 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[1],
+                                 rhs->cast<MatrixInitStmt>()->val[1])
+      i32 calc_val2 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[2],
+                                 rhs->cast<MatrixInitStmt>()->val[2])
+      i32 calc_val3 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[3],
+                                 rhs->cast<MatrixInitStmt>()->val[3])
+
+      tmp = MatrixInitStmt(calc_val0, calc_val1,
+                           calc_val2, calc_val3)
+
+      stmt->replace_all_usages_with(tmp)
+  */
+  void visit(BinaryOpStmt *stmt) override {
+    auto lhs_dtype = stmt->lhs->ret_type;
+    auto rhs_dtype = stmt->rhs->ret_type;
+
+    if (lhs_dtype->is<TensorType>() || rhs_dtype->is<TensorType>()) {
+      int lhs_num_elements = 1;
+      int rhs_num_elements = 1;
+      if (lhs_dtype->is<TensorType>()) {
+        auto lhs_tensor_type = lhs_dtype->as<TensorType>();
+        lhs_num_elements = lhs_tensor_type->get_num_elements();
+      }
+      if (rhs_dtype->is<TensorType>()) {
+        auto rhs_tensor_type = rhs_dtype->as<TensorType>();
+        rhs_num_elements = rhs_tensor_type->get_num_elements();
+      }
+
+      if (lhs_num_elements > 1 && rhs_num_elements > 1)
+        TI_ASSERT(lhs_num_elements == rhs_num_elements);
+
+      size_t num_elements = std::max(lhs_num_elements, rhs_num_elements);
+
+      std::vector<Stmt *> lhs_vals(num_elements);
+      std::vector<Stmt *> rhs_vals(num_elements);
+
+      if (lhs_dtype->is<TensorType>()) {
+        TI_ASSERT(stmt->lhs->is<MatrixInitStmt>());
+        auto lhs_matrix_init_stmt = stmt->lhs->cast<MatrixInitStmt>();
+        lhs_vals = lhs_matrix_init_stmt->values;
+      } else {
+        for (size_t i = 0; i < num_elements; i++) {
+          lhs_vals[i] = stmt->lhs;
+        }
+      }
+
+      if (rhs_dtype->is<TensorType>()) {
+        TI_ASSERT(stmt->rhs->is<MatrixInitStmt>());
+        auto rhs_matrix_init_stmt = stmt->rhs->cast<MatrixInitStmt>();
+        rhs_vals = rhs_matrix_init_stmt->values;
+      } else {
+        for (size_t i = 0; i < num_elements; i++) {
+          rhs_vals[i] = stmt->rhs;
+        }
+      }
+
+      std::vector<Stmt *> matrix_init_values;
+      for (size_t i = 0; i < num_elements; i++) {
+        auto binary_stmt = std::make_unique<BinaryOpStmt>(
+            stmt->op_type, lhs_vals[i], rhs_vals[i]);
+        matrix_init_values.push_back(binary_stmt.get());
+
+        modifier_.insert_before(stmt, std::move(binary_stmt));
+      }
+
+      auto matrix_init_stmt =
+          std::make_unique<MatrixInitStmt>(matrix_init_values);
+      matrix_init_stmt->ret_type = stmt->ret_type;
+
+      stmt->replace_usages_with(matrix_init_stmt.get());
+      modifier_.insert_before(stmt, std::move(matrix_init_stmt));
+
+      modifier_.erase(stmt);
+    }
+  }
+
   void visit(Block *stmt_list) override {
     for (auto &stmt : stmt_list->statements) {
       stmt->accept(this);
diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py
index 410557d43b0f4..5e1f333f2c3d3 100644
--- a/tests/python/test_matrix.py
+++ b/tests/python/test_matrix.py
@@ -877,3 +877,23 @@ def func(a: ti.types.ndarray()):
     assert (x[2] == [[-0., -1.], [-2., -3.]])
     assert (x[3] == [[20.08553696, 54.59814835], [148.41316223, 403.42880249]])
     assert (x[4] == [[4.48168898, 7.38905621], [12.18249416, 20.08553696]])
+
+
+@test_utils.test(arch=[ti.cuda, ti.cpu],
+                 real_matrix=True,
+                 real_matrix_scalarize=True)
+def test_binary_op_scalarize():
+    @ti.kernel
+    def func(a: ti.types.ndarray()):
+        a[0] = [[0., 1.], [2., 3.]]
+        a[1] = [[3., 4.], [5., 6.]]
+        a[2] = a[0] + a[0]
+        a[3] = a[1] * a[1]
+        a[4] = ti.max(a[2], a[3])
+
+    x = ti.Matrix.ndarray(2, 2, ti.f32, shape=5)
+    func(x)
+
+    assert (x[2] == [[0., 2.], [4., 6.]])
+    assert (x[3] == [[9., 16.], [25., 36.]])
+    assert (x[4] == [[9., 16.], [25., 36.]])

From b37816fe1ac5787219dcdb5bf31df24191aee612 Mon Sep 17 00:00:00 2001
From: jim19930609
Date: Wed, 21 Sep 2022 14:17:31 +0800
Subject: [PATCH 2/4] [Lang] Fix invalid assertion for matrix values

---
 tests/python/test_matrix.py | 37 +++++++++++++++++++++++++------------
 1 file changed, 25 insertions(+), 12 deletions(-)

diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py
index 410557d43b0f4..81d9ac969ce4b 100644
--- a/tests/python/test_matrix.py
+++ b/tests/python/test_matrix.py
@@ -25,6 +25,17 @@
 ]
 
 
+def check_matrix(mat):
+    if isinstance(mat, ti.lang.matrix.Vector):
+        assert all(mat == 1)
+    elif isinstance(mat, ti.lang.matrix.Matrix):
+        for i in range(mat.m):
+            for j in range(mat.n):
+                assert (mat[i, j] == 1)
+    else:
+        assert False
+
+
 @test_utils.test(arch=get_host_arch_list())
 def test_python_scope_vector_operations():
     for ops in vector_operation_types:
@@ -831,11 +842,11 @@ def func(a: ti.types.ndarray()):
     x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
     func(x)
 
-    assert (x[0] == [[0, 1], [2, 3]])
-    assert (x[1] == [[1, 2], [3, 4]])
-    assert (x[2] == [[2, 3], [4, 5]])
-    assert (x[3] == [[3, 4], [5, 6]])
-    assert (x[4] == [[4, 5], [6, 7]])
+    check_matrix(x[0] == [[0, 1], [2, 3]])
+    check_matrix(x[1] == [[1, 2], [3, 4]])
+    check_matrix(x[2] == [[2, 3], [4, 5]])
+    check_matrix(x[3] == [[3, 4], [5, 6]])
+    check_matrix(x[4] == [[4, 5], [6, 7]])
@@ -853,8 +864,8 @@ def func(a: ti.types.ndarray()):
     x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
     func(x)
 
-    assert (x[3] == [[1, 2], [3, 4]])
-    assert (x[4] == [[2, 3], [4, 5]])
+    check_matrix(x[3] == [[1, 2], [3, 4]])
+    check_matrix(x[4] == [[2, 3], [4, 5]])
@@ -872,8 +883,10 @@ def func(a: ti.types.ndarray()):
     x = ti.Matrix.ndarray(2, 2, ti.f32, shape=5)
     func(x)
 
-    assert (x[0] == [[0., 1.], [2., 3.]])
-    assert (x[1] == [[3., 4.], [5., 6.]])
-    assert (x[2] == [[-0., -1.], [-2., -3.]])
-    assert (x[3] == [[20.08553696, 54.59814835], [148.41316223, 403.42880249]])
-    assert (x[4] == [[4.48168898, 7.38905621], [12.18249416, 20.08553696]])
+    check_matrix(x[0] == [[0., 1.], [2., 3.]])
+    check_matrix(x[1] == [[3., 4.], [5., 6.]])
+    check_matrix(x[2] == [[-0., -1.], [-2., -3.]])
+    check_matrix(x[3] < [[20.086, 54.60], [148.42, 403.43]])
+    check_matrix(x[3] > [[20.085, 54.59], [148.41, 403.42]])
+    check_matrix(x[4] < [[4.49, 7.39], [12.19, 20.09]])
+    check_matrix(x[4] > [[4.48, 7.38], [12.18, 20.08]])

From aff7540e3e7cc9255c13bb01a1862a2d0fe5c85e Mon Sep 17 00:00:00 2001
From: jim19930609
Date: Wed, 21 Sep 2022 14:17:31 +0800
Subject: [PATCH 3/4] [Lang] Fix invalid assertion for matrix values

---
 tests/python/test_matrix.py | 26 ++++++++++++++------------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py
index 410557d43b0f4..ddfa9b66b71bc 100644
--- a/tests/python/test_matrix.py
+++ b/tests/python/test_matrix.py
@@ -831,11 +831,11 @@ def func(a: ti.types.ndarray()):
     x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
     func(x)
 
-    assert (x[0] == [[0, 1], [2, 3]])
-    assert (x[1] == [[1, 2], [3, 4]])
-    assert (x[2] == [[2, 3], [4, 5]])
-    assert (x[3] == [[3, 4], [5, 6]])
-    assert (x[4] == [[4, 5], [6, 7]])
+    assert (x[0] == [[0, 1], [2, 3]]).all()
+    assert (x[1] == [[1, 2], [3, 4]]).all()
+    assert (x[2] == [[2, 3], [4, 5]]).all()
+    assert (x[3] == [[3, 4], [5, 6]]).all()
+    assert (x[4] == [[4, 5], [6, 7]]).all()
@@ -853,8 +853,8 @@ def func(a: ti.types.ndarray()):
     x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
     func(x)
 
-    assert (x[3] == [[1, 2], [3, 4]])
-    assert (x[4] == [[2, 3], [4, 5]])
+    assert (x[3] == [[1, 2], [3, 4]]).all()
+    assert (x[4] == [[2, 3], [4, 5]]).all()
@@ -872,8 +872,10 @@ def func(a: ti.types.ndarray()):
     x = ti.Matrix.ndarray(2, 2, ti.f32, shape=5)
     func(x)
 
-    assert (x[0] == [[0., 1.], [2., 3.]])
-    assert (x[1] == [[3., 4.], [5., 6.]])
-    assert (x[2] == [[-0., -1.], [-2., -3.]])
-    assert (x[3] == [[20.08553696, 54.59814835], [148.41316223, 403.42880249]])
-    assert (x[4] == [[4.48168898, 7.38905621], [12.18249416, 20.08553696]])
+    assert (x[0] == [[0., 1.], [2., 3.]]).all()
+    assert (x[1] == [[3., 4.], [5., 6.]]).all()
+    assert (x[2] == [[-0., -1.], [-2., -3.]]).all()
+    assert (x[3] < [[20.086, 54.60], [148.42, 403.43]]).all()
+    assert (x[3] > [[20.085, 54.59], [148.41, 403.42]]).all()
+    assert (x[4] < [[4.49, 7.39], [12.19, 20.09]]).all()
+    assert (x[4] > [[4.48, 7.38], [12.18, 20.08]]).all()

From 0432ff88209b9bf6748b8a1746841ac7bedab55d Mon Sep 17 00:00:00 2001
From: jim19930609
Date: Mon, 26 Sep 2022 10:06:48 +0800
Subject: [PATCH 4/4] Bug fix

---
 python/taichi/lang/matrix.py    |  8 ++----
 taichi/ir/frontend_ir.cpp       |  2 +-
 taichi/transforms/scalarize.cpp | 50 ++++++++++-----------------------
 tests/python/test_matrix.py     |  4 +--
 4 files changed, 21 insertions(+), 43 deletions(-)

diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py
index ec28e52abe18f..c1e148ef00584 100644
--- a/python/taichi/lang/matrix.py
+++ b/python/taichi/lang/matrix.py
@@ -14,9 +14,9 @@
                                TaichiTypeError)
 from taichi.lang.field import Field, ScalarField, SNodeHostAccess
 from taichi.lang.swizzle_generator import SwizzleGenerator
-from taichi.lang.util import (cook_dtype, in_python_scope, in_taichi_scope,
-                              python_scope, taichi_scope, to_numpy_type,
-                              to_paddle_type, to_pytorch_type, warning)
+from taichi.lang.util import (cook_dtype, in_python_scope, python_scope,
+                              taichi_scope, to_numpy_type, to_paddle_type,
+                              to_pytorch_type, warning)
 from taichi.types import primitive_types
 from taichi.types.compound_types import CompoundType, TensorType
 
@@ -502,8 +502,6 @@ def __init__(self,
 
     def _element_wise_binary(self, foo, other):
         other = self._broadcast_copy(other)
-        if in_taichi_scope() and impl.current_cfg().real_matrix:
-            return foo(self, other)
         if is_col_vector(self):
             return Vector([foo(self(i), other(i)) for i in range(self.n)],
                           ndim=self.ndim)
diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp
index a26fa9979ec77..2f18fee99466f 100644
--- a/taichi/ir/frontend_ir.cpp
+++ b/taichi/ir/frontend_ir.cpp
@@ -229,7 +229,6 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
   TI_ASSERT_TYPE_CHECKED(rhs);
   auto lhs_type = lhs->ret_type;
   auto rhs_type = rhs->ret_type;
-
   auto error = [&]() {
     throw TaichiTypeError(
         fmt::format("unsupported operand type(s) for '{}': '{}' and '{}'",
@@ -318,6 +317,7 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
       rhs_type = make_dt(default_fp);
     }
   }
+  ret_type = promoted_type(lhs_type, rhs_type);
 }
 
 void BinaryOpExpression::flatten(FlattenContext *ctx) {
diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp
index e1f28e9e16074..e9cf17fe692f8 100644
--- a/taichi/transforms/scalarize.cpp
+++ b/taichi/transforms/scalarize.cpp
@@ -204,46 +204,26 @@ class Scalarize : public IRVisitor {
     auto lhs_dtype = stmt->lhs->ret_type;
     auto rhs_dtype = stmt->rhs->ret_type;
 
-    if (lhs_dtype->is<TensorType>() || rhs_dtype->is<TensorType>()) {
-      int lhs_num_elements = 1;
-      int rhs_num_elements = 1;
-      if (lhs_dtype->is<TensorType>()) {
-        auto lhs_tensor_type = lhs_dtype->as<TensorType>();
-        lhs_num_elements = lhs_tensor_type->get_num_elements();
-      }
-      if (rhs_dtype->is<TensorType>()) {
-        auto rhs_tensor_type = rhs_dtype->as<TensorType>();
-        rhs_num_elements = rhs_tensor_type->get_num_elements();
-      }
+    // BinaryOpExpression::type_check() should have taken care of the
+    // broadcasting and necessary conversions. So we simply add an assertion
+    // here to make sure that the operands are of the same shape and dtype.
+    TI_ASSERT(lhs_dtype == rhs_dtype);
 
-      if (lhs_num_elements > 1 && rhs_num_elements > 1)
-        TI_ASSERT(lhs_num_elements == rhs_num_elements);
+    if (lhs_dtype->is<TensorType>() && rhs_dtype->is<TensorType>()) {
+      // Scalarization for LoadStmt should have already replaced both
+      // operands with MatrixInitStmt.
+      TI_ASSERT(stmt->lhs->is<MatrixInitStmt>());
+      TI_ASSERT(stmt->rhs->is<MatrixInitStmt>());
 
-      size_t num_elements = std::max(lhs_num_elements, rhs_num_elements);
+      auto lhs_matrix_init_stmt = stmt->lhs->cast<MatrixInitStmt>();
+      std::vector<Stmt *> lhs_vals = lhs_matrix_init_stmt->values;
 
-      std::vector<Stmt *> lhs_vals(num_elements);
-      std::vector<Stmt *> rhs_vals(num_elements);
+      auto rhs_matrix_init_stmt = stmt->rhs->cast<MatrixInitStmt>();
+      std::vector<Stmt *> rhs_vals = rhs_matrix_init_stmt->values;
 
-      if (lhs_dtype->is<TensorType>()) {
-        TI_ASSERT(stmt->lhs->is<MatrixInitStmt>());
-        auto lhs_matrix_init_stmt = stmt->lhs->cast<MatrixInitStmt>();
-        lhs_vals = lhs_matrix_init_stmt->values;
-      } else {
-        for (size_t i = 0; i < num_elements; i++) {
-          lhs_vals[i] = stmt->lhs;
-        }
-      }
+      TI_ASSERT(rhs_vals.size() == lhs_vals.size());
 
-      if (rhs_dtype->is<TensorType>()) {
-        TI_ASSERT(stmt->rhs->is<MatrixInitStmt>());
-        auto rhs_matrix_init_stmt = stmt->rhs->cast<MatrixInitStmt>();
-        rhs_vals = rhs_matrix_init_stmt->values;
-      } else {
-        for (size_t i = 0; i < num_elements; i++) {
-          rhs_vals[i] = stmt->rhs;
-        }
-      }
+      size_t num_elements = lhs_vals.size();
 
       std::vector<Stmt *> matrix_init_values;
       for (size_t i = 0; i < num_elements; i++) {
         auto binary_stmt = std::make_unique<BinaryOpStmt>(
             stmt->op_type, lhs_vals[i], rhs_vals[i]);
diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py
index 60bbc166d156d..74ddb3f706617 100644
--- a/tests/python/test_matrix.py
+++ b/tests/python/test_matrix.py
@@ -976,8 +976,8 @@ def verify(x):
                  real_matrix=True,
                  real_matrix_scalarize=True)
 def test_binary_op_scalarize():
-    @ti.kernel
-    def func(a: ti.types.ndarray()):
+    @ti.func
+    def func(a: ti.template()):
         a[0] = [[0., 1.], [2., 3.]]
         a[1] = [[3., 4.], [5., 6.]]
         a[2] = a[0] + a[0]
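
The two sketches below are illustrative notes, not part of the patch series. First, the truediv branch in BinaryOpExpression::type_check() promotes integral operands to config->default_fp before computing the promoted element type, and the TensorType reconstruction keeps the operands' shape. The snippet shows the resulting user-visible behavior, assuming a build of this branch where ti.init() accepts the experimental real_matrix flags the same way the @test_utils.test decorators above do; the kernel name "divide" is made up for the example.

    # A minimal sketch, assuming ti.init() forwards the experimental
    # real_matrix flags used by the test decorators above.
    import taichi as ti

    ti.init(arch=ti.cpu, real_matrix=True, real_matrix_scalarize=True)

    @ti.kernel
    def divide(a: ti.types.ndarray(), b: ti.types.ndarray(),
               c: ti.types.ndarray()):
        a[0] = [[1, 2], [3, 4]]
        b[0] = [[2, 2], [2, 2]]
        # truediv on two i32 matrices: both element types are promoted to
        # default_fp (f32 by default), so c[0] receives f32 values with the
        # same 2x2 TensorType shape.
        c[0] = a[0] / b[0]

    x = ti.Matrix.ndarray(2, 2, ti.i32, shape=1)
    y = ti.Matrix.ndarray(2, 2, ti.i32, shape=1)
    z = ti.Matrix.ndarray(2, 2, ti.f32, shape=1)
    divide(x, y, z)
    assert (z[0] == [[0.5, 1.], [1.5, 2.]]).all()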
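Second, a plain-Python reference model of what visit(BinaryOpStmt) emits after patch 4. scalarize_binary is a hypothetical name used only for this illustration; it is not Taichi API.

    # Hedged reference model of the scalarization in
    # taichi/transforms/scalarize.cpp: a BinaryOpStmt on TensorType operands
    # becomes one scalar BinaryOpStmt per element, and a MatrixInitStmt
    # regroups the scalar results.
    def scalarize_binary(op, lhs_vals, rhs_vals):
        # Patch 4 asserts that both operands were already lowered to
        # MatrixInitStmt with matching shape and dtype, i.e. two flat value
        # lists of equal length.
        assert len(lhs_vals) == len(rhs_vals)
        # One scalar op per element; the returned list plays the role of the
        # regrouping MatrixInitStmt.
        return [op(l, r) for l, r in zip(lhs_vals, rhs_vals)]

    # Mirrors the test above: a[2] = a[0] + a[0] with a[0] = [[0., 1.], [2., 3.]]
    assert scalarize_binary(lambda u, v: u + v, [0., 1., 2., 3.],
                            [0., 1., 2., 3.]) == [0., 2., 4., 6.]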