From b7f99e0f99887390d43e3e65cdec784d7b3bf9ce Mon Sep 17 00:00:00 2001
From: jim19930609
Date: Thu, 13 Oct 2022 16:11:10 +0800
Subject: [PATCH 1/2] [Lang] MatrixNdarray refactor part13: Add scalarization
 for TernaryOpStmt

---
 taichi/ir/frontend_ir.cpp       | 50 ++++++++++++++++++---
 taichi/transforms/scalarize.cpp | 78 +++++++++++++++++++++++++++++++++
 tests/python/test_matrix.py     | 20 +++++++++
 3 files changed, 143 insertions(+), 5 deletions(-)

diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp
index 53125dd2c71e7..8c3307851b2cf 100644
--- a/taichi/ir/frontend_ir.cpp
+++ b/taichi/ir/frontend_ir.cpp
@@ -391,24 +391,63 @@ void make_ifte(Expression::FlattenContext *ctx,
   return;
 }
 
-void TernaryOpExpression::type_check(CompileConfig *) {
+void TernaryOpExpression::type_check(CompileConfig *config) {
   TI_ASSERT_TYPE_CHECKED(op1);
   TI_ASSERT_TYPE_CHECKED(op2);
   TI_ASSERT_TYPE_CHECKED(op3);
   auto op1_type = op1->ret_type;
   auto op2_type = op2->ret_type;
   auto op3_type = op3->ret_type;
+
   auto error = [&]() {
     throw TaichiTypeError(
         fmt::format("unsupported operand type(s) for '{}': '{}', '{}' and '{}'",
                     ternary_type_name(type), op1->ret_type->to_string(),
                     op2->ret_type->to_string(), op3->ret_type->to_string()));
   };
-  if (op1_type != PrimitiveType::i32)
-    error();
-  if (!op2_type->is<PrimitiveType>() || !op3_type->is<PrimitiveType>())
+
+  bool is_valid = true;
+  bool is_tensor = false;
+  if (op1_type->is<TensorType>() && op2_type->is<TensorType>() &&
+      op3_type->is<TensorType>()) {
+    // valid
+    is_tensor = true;
+    if (op1_type->cast<TensorType>()->get_shape().size() !=
+        op2_type->cast<TensorType>()->get_shape().size()) {
+      is_valid = false;
+    }
+    if (op2_type->cast<TensorType>()->get_shape().size() !=
+        op3_type->cast<TensorType>()->get_shape().size()) {
+      is_valid = false;
+    }
+    op1_type = op1_type->cast<TensorType>()->get_element_type();
+    op2_type = op2_type->cast<TensorType>()->get_element_type();
+    op3_type = op3_type->cast<TensorType>()->get_element_type();
+
+  } else if (op1_type->is<PrimitiveType>() && op2_type->is<PrimitiveType>() &&
+             op3_type->is<PrimitiveType>()) {
+    // valid
+  } else {
+    is_valid = false;
+  }
+
+  if (op1_type != PrimitiveType::i32) {
+    is_valid = false;
+  }
+  if (!op2_type->is<PrimitiveType>() || !op3_type->is<PrimitiveType>()) {
+    is_valid = false;
+  }
+
+  if (!is_valid)
     error();
-  ret_type = promoted_type(op2_type, op3_type);
+
+  if (is_tensor) {
+    auto primitive_dtype = promoted_type(op2_type, op3_type);
+    ret_type = TypeFactory::create_tensor_type(
+        op2->ret_type->cast<TensorType>()->get_shape(), primitive_dtype);
+  } else {
+    ret_type = promoted_type(op2_type, op3_type);
+  }
 }
 
 void TernaryOpExpression::flatten(FlattenContext *ctx) {
@@ -425,6 +464,7 @@ void TernaryOpExpression::flatten(FlattenContext *ctx) {
   }
   stmt = ctx->back_stmt();
   stmt->tb = tb;
+  stmt->ret_type = ret_type;
 }
 
 void InternalFuncCallExpression::type_check(CompileConfig *) {
diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp
index 7632011b79e83..ee7475df972f6 100644
--- a/taichi/transforms/scalarize.cpp
+++ b/taichi/transforms/scalarize.cpp
@@ -258,6 +258,84 @@ class Scalarize : public BasicStmtVisitor {
     }
   }
 
+  /*
+    Before:
+      TensorType<4 x i32> val = TernaryStmt(TensorType<4 x i32> cond,
+                                            TensorType<4 x i32> lhs,
+                                            TensorType<4 x i32> rhs)
+
+    After:
+      i32 val0 = TernaryStmt(cond->cast<MatrixInitStmt>()->val[0],
+                             lhs->cast<MatrixInitStmt>()->val[0],
+                             rhs->cast<MatrixInitStmt>()->val[0])
+
+      i32 val1 = TernaryStmt(cond->cast<MatrixInitStmt>()->val[1],
+                             lhs->cast<MatrixInitStmt>()->val[1],
+                             rhs->cast<MatrixInitStmt>()->val[1])
+
+      i32 val2 = TernaryStmt(cond->cast<MatrixInitStmt>()->val[2],
+                             lhs->cast<MatrixInitStmt>()->val[2],
+                             rhs->cast<MatrixInitStmt>()->val[2])
+
+      i32 val3 = TernaryStmt(cond->cast<MatrixInitStmt>()->val[3],
+                             lhs->cast<MatrixInitStmt>()->val[3],
+                             rhs->cast<MatrixInitStmt>()->val[3])
+
+      tmp = MatrixInitStmt(val0, val1, val2, val3)
+
+      stmt->replace_all_usages_with(tmp)
+  */
+  void visit(TernaryOpStmt *stmt) override {
+    auto cond_dtype = stmt->op1->ret_type;
+    auto op2_dtype = stmt->op2->ret_type;
+    auto op3_dtype = stmt->op3->ret_type;
+
+    if (cond_dtype->is<PrimitiveType>() && op2_dtype->is<PrimitiveType>() &&
+        op3_dtype->is<PrimitiveType>()) {
+      return;
+    }
+
+    if (cond_dtype->is<TensorType>() && op2_dtype->is<TensorType>() &&
+        op3_dtype->is<TensorType>()) {
+      TI_ASSERT(stmt->op1->is<MatrixInitStmt>());
+      TI_ASSERT(stmt->op2->is<MatrixInitStmt>());
+      TI_ASSERT(stmt->op3->is<MatrixInitStmt>());
+
+      auto cond_matrix_init_stmt = stmt->op1->cast<MatrixInitStmt>();
+      std::vector<Stmt *> cond_vals = cond_matrix_init_stmt->values;
+
+      auto op2_matrix_init_stmt = stmt->op2->cast<MatrixInitStmt>();
+      std::vector<Stmt *> op2_vals = op2_matrix_init_stmt->values;
+
+      auto op3_matrix_init_stmt = stmt->op3->cast<MatrixInitStmt>();
+      std::vector<Stmt *> op3_vals = op3_matrix_init_stmt->values;
+
+      TI_ASSERT(cond_vals.size() == op2_vals.size());
+      TI_ASSERT(op2_vals.size() == op3_vals.size());
+
+      size_t num_elements = cond_vals.size();
+      auto primitive_type = stmt->ret_type.get_element_type();
+      std::vector<Stmt *> matrix_init_values;
+      for (size_t i = 0; i < num_elements; i++) {
+        auto ternary_stmt = std::make_unique<TernaryOpStmt>(
+            stmt->op_type, cond_vals[i], op2_vals[i], op3_vals[i]);
+        matrix_init_values.push_back(ternary_stmt.get());
+        ternary_stmt->ret_type = primitive_type;
+
+        modifier_.insert_before(stmt, std::move(ternary_stmt));
+      }
+
+      auto matrix_init_stmt =
+          std::make_unique<MatrixInitStmt>(matrix_init_values);
+      matrix_init_stmt->ret_type = stmt->ret_type;
+
+      stmt->replace_usages_with(matrix_init_stmt.get());
+      modifier_.insert_before(stmt, std::move(matrix_init_stmt));
+
+      modifier_.erase(stmt);
+    }
+  }
+
   void visit(GlobalStoreStmt *stmt) override {
     scalarize_store_stmt<GlobalStoreStmt>(stmt);
   }
diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py
index 9e55c1e24797f..7fe083e74a74c 100644
--- a/tests/python/test_matrix.py
+++ b/tests/python/test_matrix.py
@@ -1044,3 +1044,23 @@ def verify(x):
     field = ti.Matrix.field(2, 2, ti.f32, shape=5)
     ndarray = ti.Matrix.ndarray(2, 2, ti.f32, shape=5)
     _test_field_and_ndarray(field, ndarray, func, verify)
+
+
+@test_utils.test(arch=[ti.cuda, ti.cpu],
+                 real_matrix=True,
+                 real_matrix_scalarize=True,
+                 debug=True)
+def test_ternary_op_scalarize():
+    @ti.kernel
+    def test():
+        cond = ti.Vector([1, 0, 1])
+        x = ti.Vector([3, 3, 3])
+        y = ti.Vector([5, 5, 5])
+
+        z = ti.select(cond, x, y)
+
+        assert z[0] == 3
+        assert z[1] == 5
+        assert z[2] == 3
+
+    test()

From 4126d824e459e34ed902b5544153717b4eb37263 Mon Sep 17 00:00:00 2001
From: jim19930609
Date: Mon, 17 Oct 2022 10:05:03 +0800
Subject: [PATCH 2/2] Add comments

---
 taichi/ir/frontend_ir.cpp       | 8 ++++----
 taichi/transforms/scalarize.cpp | 6 ++++++
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp
index 8c3307851b2cf..d719262c346a9 100644
--- a/taichi/ir/frontend_ir.cpp
+++ b/taichi/ir/frontend_ir.cpp
@@ -412,12 +412,12 @@ void TernaryOpExpression::type_check(CompileConfig *config) {
       op3_type->is<TensorType>()) {
     // valid
     is_tensor = true;
-    if (op1_type->cast<TensorType>()->get_shape().size() !=
-        op2_type->cast<TensorType>()->get_shape().size()) {
+    if (op1_type->cast<TensorType>()->get_shape() !=
+        op2_type->cast<TensorType>()->get_shape()) {
       is_valid = false;
     }
-    if (op2_type->cast<TensorType>()->get_shape().size() !=
-        op3_type->cast<TensorType>()->get_shape().size()) {
+    if (op2_type->cast<TensorType>()->get_shape() !=
+        op3_type->cast<TensorType>()->get_shape()) {
       is_valid = false;
     }
     op1_type = op1_type->cast<TensorType>()->get_element_type();
diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp
index ee7475df972f6..95769b4868115 100644
--- a/taichi/transforms/scalarize.cpp
+++ b/taichi/transforms/scalarize.cpp
@@ -295,6 +295,12 @@ class Scalarize : public BasicStmtVisitor {
       return;
     }
 
+    // TernaryOpExpression::type_check() have taken care of the broadcasting,
+    // but the type conversions are delayed until irpass::type_check().
+    // So we only check for the shape here.
+    TI_ASSERT(cond_dtype.get_shape() == op2_dtype.get_shape());
+    TI_ASSERT(op2_dtype.get_shape() == op3_dtype.get_shape());
+
     if (cond_dtype->is<TensorType>() && op2_dtype->is<TensorType>() &&
         op3_dtype->is<TensorType>()) {
       TI_ASSERT(stmt->op1->is<MatrixInitStmt>());