Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Lang] MatrixNdarray refactor part8: Add scalarization for BinaryOpStmt with TensorType-operands #6086

Merged
merged 8 commits into from
Sep 26, 2022
2 changes: 2 additions & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,12 +350,14 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) {
auto ret = ctx->push_back<LocalLoadStmt>(result);
ret->tb = tb;
stmt = ret;
stmt->ret_type = ret_type;
return;
}
flatten_rvalue(rhs, ctx);
ctx->push_back(std::make_unique<BinaryOpStmt>(type, lhs->stmt, rhs->stmt));
ctx->stmts.back()->tb = tb;
stmt = ctx->back_stmt();
stmt->ret_type = ret_type;
}

void make_ifte(Expression::FlattenContext *ctx,
Expand Down
68 changes: 67 additions & 1 deletion taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ class Scalarize : public IRVisitor {

auto matrix_init_stmt =
std::make_unique<MatrixInitStmt>(matrix_init_values);

matrix_init_stmt->ret_type = src_dtype;

stmt->replace_usages_with(matrix_init_stmt.get());
Expand Down Expand Up @@ -178,6 +177,73 @@ class Scalarize : public IRVisitor {
}
}

/*
  Before:
    TensorType<4 x i32> val = BinaryStmt(TensorType<4 x i32> lhs,
                                         TensorType<4 x i32> rhs)

  * Note that "lhs" and "rhs" should have already been scalarized to
    MatrixInitStmt

  After:
    i32 calc_val0 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[0],
                               rhs->cast<MatrixInitStmt>()->val[0])
    i32 calc_val1 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[1],
                               rhs->cast<MatrixInitStmt>()->val[1])
    i32 calc_val2 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[2],
                               rhs->cast<MatrixInitStmt>()->val[2])
    i32 calc_val3 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[3],
                               rhs->cast<MatrixInitStmt>()->val[3])

    tmp = MatrixInitStmt(calc_val0, calc_val1,
                         calc_val2, calc_val3)

    stmt->replace_usages_with(tmp)
*/
void visit(BinaryOpStmt *stmt) override {
  auto lhs_dtype = stmt->lhs->ret_type;
  auto rhs_dtype = stmt->rhs->ret_type;

  // BinaryOpExpression::type_check() should have taken care of the
  // broadcasting and necessary conversions. So we simply add an assertion
  // here to make sure that the operands are of the same shape and dtype
  TI_ASSERT(lhs_dtype == rhs_dtype);

  if (lhs_dtype->is<TensorType>() && rhs_dtype->is<TensorType>()) {
    // Scalarization for LoadStmt should have already replaced both operands
    // to MatrixInitStmt
    TI_ASSERT(stmt->lhs->is<MatrixInitStmt>());
    TI_ASSERT(stmt->rhs->is<MatrixInitStmt>());

    auto lhs_matrix_init_stmt = stmt->lhs->cast<MatrixInitStmt>();
    std::vector<Stmt *> lhs_vals = lhs_matrix_init_stmt->values;

    auto rhs_matrix_init_stmt = stmt->rhs->cast<MatrixInitStmt>();
    std::vector<Stmt *> rhs_vals = rhs_matrix_init_stmt->values;

    TI_ASSERT(rhs_vals.size() == lhs_vals.size());

    // This pass runs after type checking, so newly created statements must
    // be typed by hand: each scalarized BinaryOpStmt yields one element of
    // the result TensorType.
    auto element_type = stmt->ret_type->as<TensorType>()->get_element_type();

    size_t num_elements = lhs_vals.size();
    std::vector<Stmt *> matrix_init_values;
    for (size_t i = 0; i < num_elements; i++) {
      auto binary_stmt = std::make_unique<BinaryOpStmt>(
          stmt->op_type, lhs_vals[i], rhs_vals[i]);
      binary_stmt->ret_type = element_type;
      // Preserve the traceback of the original statement so that later
      // error reporting still points at the user's source line.
      binary_stmt->tb = stmt->tb;
      matrix_init_values.push_back(binary_stmt.get());

      modifier_.insert_before(stmt, std::move(binary_stmt));
    }

    auto matrix_init_stmt =
        std::make_unique<MatrixInitStmt>(matrix_init_values);
    matrix_init_stmt->ret_type = stmt->ret_type;
    matrix_init_stmt->tb = stmt->tb;

    stmt->replace_usages_with(matrix_init_stmt.get());
    modifier_.insert_before(stmt, std::move(matrix_init_stmt));

    modifier_.erase(stmt);
  }
}

void visit(Block *stmt_list) override {
for (auto &stmt : stmt_list->statements) {
stmt->accept(this);
Expand Down
22 changes: 22 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,3 +970,25 @@ def verify(x):
field = ti.Matrix.field(2, 2, ti.f32, shape=5)
ndarray = ti.Matrix.ndarray(2, 2, ti.f32, shape=5)
_test_field_and_ndarray(field, ndarray, func, verify)


@test_utils.test(arch=[ti.cuda, ti.cpu],
                 real_matrix=True,
                 real_matrix_scalarize=True)
def test_binary_op_scalarize():
    # Exercise the scalarize pass on element-wise binary ops (+, *, max)
    # applied to matrix values, for both fields and ndarrays.
    @ti.func
    def func(m: ti.template()):
        m[0] = [[0., 1.], [2., 3.]]
        m[1] = [[3., 4.], [5., 6.]]
        m[2] = m[0] + m[0]
        m[3] = m[1] * m[1]
        m[4] = ti.max(m[2], m[3])

    def verify(x):
        # Expected element-wise results, keyed by the slot written in func().
        expected = {
            2: [[0., 2.], [4., 6.]],
            3: [[9., 16.], [25., 36.]],
            4: [[9., 16.], [25., 36.]],
        }
        for idx, want in expected.items():
            assert (x[idx] == want).all()

    field = ti.Matrix.field(2, 2, ti.f32, shape=5)
    ndarray = ti.Matrix.ndarray(2, 2, ti.f32, shape=5)
    _test_field_and_ndarray(field, ndarray, func, verify)