[Lang] MatrixNdarray refactor part6: Add scalarization for LocalLoadStmt & GlobalLoadStmt with TensorType (#6024)

Related issues: #5873, #5819

This PR implements "Part ④" of #5873.
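For context, the new pass sits behind two CompileConfig flags. Below is a minimal sketch of a kernel that exercises it, assuming the flags are forwarded through ti.init() like other CompileConfig fields; the kernel name, shapes, and values are illustrative and mirror the new tests in this PR:

    import taichi as ti

    # Both flags are needed: real_matrix keeps matrices as TensorType values
    # in the IR, and real_matrix_scalarize runs the new scalarization pass.
    ti.init(arch=ti.cpu, real_matrix=True, real_matrix_scalarize=True)

    @ti.kernel
    def fill(a: ti.types.ndarray()):
        for i in range(5):
            # Whole-matrix store: scalarized into four element-wise stores.
            a[i] = [[i, i + 1], [i + 2, i + 3]]

    x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
    fill(x)
    assert x[0] == [[0, 1], [2, 3]]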
jim19930609 authored Sep 16, 2022
1 parent 8fd7522 commit e23ad2d
Showing 7 changed files with 187 additions and 10 deletions.
9 changes: 8 additions & 1 deletion python/taichi/lang/ast/ast_transformer.py
@@ -15,8 +15,9 @@
 from taichi.lang.ast.symbol_resolver import ASTResolver
 from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
 from taichi.lang.field import Field
+from taichi.lang.impl import current_cfg
 from taichi.lang.matrix import (Matrix, MatrixType, Vector, _PyScopeMatrixImpl,
-                                _TiScopeMatrixImpl)
+                                _TiScopeMatrixImpl, make_matrix)
 from taichi.lang.snode import append
 from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type
 from taichi.types import (annotations, ndarray_type, primitive_types,
@@ -114,6 +115,12 @@ def build_Assign(ctx, node):
     @staticmethod
     def build_assign_slice(ctx, node_target, values, is_static_assign):
         target = ASTTransformer.build_Subscript(ctx, node_target, get_ref=True)
+        if current_cfg().real_matrix:
+            if isinstance(node_target.value.ptr,
+                          any_array.AnyArray) and isinstance(
+                              values, (list, tuple)):
+                values = make_matrix(values)
+
         if isinstance(node_target.value.ptr, Matrix):
             if isinstance(node_target.value.ptr._impl, _TiScopeMatrixImpl):
                 target._assign(values)
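The new branch above converts a Python list/tuple on the right-hand side of an ndarray slice assignment into a single matrix expression before the store is built, so the rest of the pipeline sees one TensorType-typed value. A conceptual sketch of the rewrite (make_matrix is the real helper from the diff; the surrounding lines are not the transformer's literal output):

    # Inside a kernel, with real_matrix enabled:
    a[i] = [[1, 2], [3, 4]]
    # build_assign_slice now behaves roughly as if the user had written:
    a[i] = make_matrix([[1, 2], [3, 4]])  # one TensorType-valued store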
2 changes: 1 addition & 1 deletion python/taichi/lang/impl.py
@@ -16,7 +16,7 @@
 from taichi.lang.kernel_arguments import SparseMatrixProxy
 from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType,
                                 Vector, _IntermediateMatrix,
-                                _MatrixFieldElement, make_matrix)
+                                _MatrixFieldElement)
 from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance,
                               MeshRelationAccessProxy,
                               MeshReorderedMatrixFieldProxy,
1 change: 1 addition & 0 deletions taichi/analysis/offline_cache_util.cpp
@@ -66,6 +66,7 @@ static std::vector<std::uint8_t> get_offline_cache_key_of_compile_config(
   serializer(config->experimental_auto_mesh_local);
   serializer(config->auto_mesh_local_default_occupacy);
   serializer(config->real_matrix);
+  serializer(config->real_matrix_scalarize);
   serializer.finalize();

   return serializer.data;
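The new flag has to participate in the offline-cache key: two runs that differ only in real_matrix_scalarize compile different IR, so their cached kernels must not collide. A toy Python model of the idea (the real key is the serializer hash over CompileConfig shown above; cache_key and the two-field list are hypothetical stand-ins):

    import hashlib

    def cache_key(config: dict) -> str:
        # Every codegen-affecting field is folded into the key. Omitting
        # real_matrix_scalarize would let a kernel cached with the pass off
        # be wrongly reused when the pass is on.
        fields = ["real_matrix", "real_matrix_scalarize"]
        blob = "|".join(f"{k}={config.get(k)}" for k in fields)
        return hashlib.sha256(blob.encode()).hexdigest()

    assert cache_key({"real_matrix": True, "real_matrix_scalarize": True}) \
        != cache_key({"real_matrix": True, "real_matrix_scalarize": False})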
2 changes: 2 additions & 0 deletions taichi/python/export_lang.cpp
@@ -203,6 +203,8 @@ void export_lang(py::module &m) {
                      &CompileConfig::ndarray_use_cached_allocator)
       .def_readwrite("use_mesh", &CompileConfig::use_mesh)
       .def_readwrite("real_matrix", &CompileConfig::real_matrix)
+      .def_readwrite("real_matrix_scalarize",
+                     &CompileConfig::real_matrix_scalarize)
       .def_readwrite("cc_compile_cmd", &CompileConfig::cc_compile_cmd)
       .def_readwrite("cc_link_cmd", &CompileConfig::cc_link_cmd)
       .def_readwrite("quant_opt_store_fusion",
79 changes: 74 additions & 5 deletions taichi/transforms/scalarize.cpp
@@ -8,10 +8,14 @@ TLANG_NAMESPACE_BEGIN

 class Scalarize : public IRVisitor {
  public:
+  DelayedIRModifier modifier_;
+
   Scalarize(IRNode *node) {
     allow_undefined_visitor = true;
     invoke_default_visitor = false;
     node->accept(this);
+
+    modifier_.modify_ir();
   }

   /*
@@ -51,18 +55,75 @@ class Scalarize : public IRVisitor {
       int num_elements = val_tensor_type->get_num_elements();
       for (int i = 0; i < num_elements; i++) {
         auto const_stmt = std::make_unique<ConstStmt>(
-            TypedConstant(stmt->val->ret_type.get_element_type(), i));
+            TypedConstant(get_data_type<int32>(), i));

         auto ptr_offset_stmt =
             std::make_unique<PtrOffsetStmt>(stmt->dest, const_stmt.get());
         auto scalarized_stmt = std::make_unique<T>(ptr_offset_stmt.get(),
                                                    matrix_init_stmt->values[i]);

-        stmt->insert_before_me(std::move(const_stmt));
-        stmt->insert_before_me(std::move(ptr_offset_stmt));
-        stmt->insert_before_me(std::move(scalarized_stmt));
+        modifier_.insert_before(stmt, std::move(const_stmt));
+        modifier_.insert_before(stmt, std::move(ptr_offset_stmt));
+        modifier_.insert_before(stmt, std::move(scalarized_stmt));
       }
+      modifier_.erase(stmt);
     }
   }

+  /*
+    Before:
+      TensorType<4 x i32> val = LoadStmt(TensorType<4 x i32>* src)
+    After:
+      i32* addr0 = PtrOffsetStmt(TensorType<4 x i32>* src, 0)
+      i32* addr1 = PtrOffsetStmt(TensorType<4 x i32>* src, 1)
+      i32* addr2 = PtrOffsetStmt(TensorType<4 x i32>* src, 2)
+      i32* addr3 = PtrOffsetStmt(TensorType<4 x i32>* src, 3)
+      i32 val0 = LoadStmt(addr0)
+      i32 val1 = LoadStmt(addr1)
+      i32 val2 = LoadStmt(addr2)
+      i32 val3 = LoadStmt(addr3)
+      tmp = MatrixInitStmt(val0, val1, val2, val3)
+      stmt->replace_all_usages_with(tmp)
+  */
+  template <typename T>
+  void scalarize_load_stmt(T *stmt) {
+    auto src_dtype = stmt->src->ret_type.ptr_removed();
+    if (src_dtype->template is<TensorType>()) {
+      // Needs scalarize
+      auto src_tensor_type = src_dtype->template as<TensorType>();
+
+      std::vector<Stmt *> matrix_init_values;
+      int num_elements = src_tensor_type->get_num_elements();
+
+      for (int i = 0; i < num_elements; i++) {
+        auto const_stmt = std::make_unique<ConstStmt>(
+            TypedConstant(get_data_type<int32>(), i));
+
+        auto ptr_offset_stmt =
+            std::make_unique<PtrOffsetStmt>(stmt->src, const_stmt.get());
+        auto scalarized_stmt = std::make_unique<T>(ptr_offset_stmt.get());
+
+        matrix_init_values.push_back(scalarized_stmt.get());
+
+        modifier_.insert_before(stmt, std::move(const_stmt));
+        modifier_.insert_before(stmt, std::move(ptr_offset_stmt));
+        modifier_.insert_before(stmt, std::move(scalarized_stmt));
+      }
+
+      auto matrix_init_stmt =
+          std::make_unique<MatrixInitStmt>(matrix_init_values);
+
+      matrix_init_stmt->ret_type = src_dtype;
+
+      stmt->replace_usages_with(matrix_init_stmt.get());
+      modifier_.insert_before(stmt, std::move(matrix_init_stmt));
+
+      modifier_.erase(stmt);
+    }
+  }

@@ -107,6 +168,14 @@
   void visit(LocalStoreStmt *stmt) override {
     scalarize_store_stmt<LocalStoreStmt>(stmt);
   }
+
+  void visit(GlobalLoadStmt *stmt) override {
+    scalarize_load_stmt<GlobalLoadStmt>(stmt);
+  }
+
+  void visit(LocalLoadStmt *stmt) override {
+    scalarize_load_stmt<LocalLoadStmt>(stmt);
+  }
 };

 namespace irpass {
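The Before/After comment in scalarize_load_stmt is the whole algorithm. As a self-contained sketch, here is the same rewrite modeled on plain Python lists; the class names are illustrative stand-ins, not Taichi's IR types:

    from dataclasses import dataclass

    @dataclass
    class Stmt:
        pass

    @dataclass
    class Alloca(Stmt):  # stands in for the TensorType-typed source
        num_elements: int

    @dataclass
    class Const(Stmt):
        value: int

    @dataclass
    class PtrOffset(Stmt):
        base: Stmt
        offset: Const

    @dataclass
    class Load(Stmt):
        src: Stmt

    @dataclass
    class MatrixInit(Stmt):
        values: list

    def scalarize_load(stmts, load):
        # Mirrors scalarize_load_stmt: one Const/PtrOffset/Load triple per
        # element, then a MatrixInit that regroups the scalars. The single
        # splice at the end plays the role of DelayedIRModifier's deferred
        # mutation (nothing is erased while the IR is being walked).
        pos = stmts.index(load)
        new_stmts, values = [], []
        for i in range(load.src.num_elements):
            const = Const(i)                   # i32 element index
            addr = PtrOffset(load.src, const)  # scalar pointer into tensor
            val = Load(addr)                   # scalar load
            new_stmts += [const, addr, val]
            values.append(val)
        init = MatrixInit(values)  # users of `load` get redirected here
        stmts[pos:pos + 1] = new_stmts + [init]

    src = Alloca(num_elements=4)
    ir = [src, Load(src)]
    scalarize_load(ir, ir[1])
    assert len(ir) == 1 + 4 * 3 + 1  # same count the C++ unit test checks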
66 changes: 63 additions & 3 deletions tests/cpp/transforms/scalarize_test.cpp
@@ -9,7 +9,7 @@ namespace lang {

 // Basic tests within a basic block
 template <typename T>
-void test_scalarize() {
+void test_store_scalarize() {
   TestProgram test_prog;
   test_prog.setup();

@@ -76,9 +76,69 @@
   EXPECT_EQ(block->statements[16]->is<T>(), true);
 }

+template <typename T>
+void test_load_scalarize() {
+  TestProgram test_prog;
+  test_prog.setup();
+
+  auto block = std::make_unique<Block>();
+
+  auto func = []() {};
+  auto kernel =
+      std::make_unique<Kernel>(*test_prog.prog(), func, "fake_kernel");
+  block->kernel = kernel.get();
+
+  auto &type_factory = TypeFactory::get_instance();
+
+  /*
+    TensorType<4 x i32>* %1 = ExternalPtrStmt()
+    TensorType<4 x i32>  %2 = LoadStmt(%1)
+  */
+  Type *tensor_type = type_factory.get_tensor_type(
+      {2, 2}, type_factory.get_primitive_type(PrimitiveTypeID::i32));
+  auto argload_stmt = block->push_back<ArgLoadStmt>(0 /*arg_id*/, tensor_type);
+
+  std::vector<Stmt *> indices = {};
+  Stmt *src_stmt = block->push_back<ExternalPtrStmt>(
+      argload_stmt, indices);  // fake ExternalPtrStmt
+  src_stmt->ret_type = type_factory.get_pointer_type(tensor_type);
+
+  block->push_back<T>(src_stmt);
+
+  irpass::scalarize(block.get());
+
+  EXPECT_EQ(block->size(), 1 /*argload*/ + 1 /*external_ptr*/ + 4 /*const*/ +
+                               4 /*ptroffset*/ + 4 /*load*/ +
+                               1 /*matrix_init*/);
+
+  // Check for scalarized statements
+  EXPECT_EQ(block->statements[2]->is<ConstStmt>(), true);
+  EXPECT_EQ(block->statements[3]->is<PtrOffsetStmt>(), true);
+  EXPECT_EQ(block->statements[4]->is<T>(), true);
+
+  EXPECT_EQ(block->statements[5]->is<ConstStmt>(), true);
+  EXPECT_EQ(block->statements[6]->is<PtrOffsetStmt>(), true);
+  EXPECT_EQ(block->statements[7]->is<T>(), true);
+
+  EXPECT_EQ(block->statements[8]->is<ConstStmt>(), true);
+  EXPECT_EQ(block->statements[9]->is<PtrOffsetStmt>(), true);
+  EXPECT_EQ(block->statements[10]->is<T>(), true);
+
+  EXPECT_EQ(block->statements[11]->is<ConstStmt>(), true);
+  EXPECT_EQ(block->statements[12]->is<PtrOffsetStmt>(), true);
+  EXPECT_EQ(block->statements[13]->is<T>(), true);
+
+  EXPECT_EQ(block->statements[14]->is<MatrixInitStmt>(), true);
+}
+
 TEST(Scalarize, ScalarizeStore) {
-  test_scalarize<GlobalStoreStmt>();
-  test_scalarize<LocalStoreStmt>();
+  test_store_scalarize<GlobalStoreStmt>();
+  test_store_scalarize<LocalStoreStmt>();
 }
+
+TEST(Scalarize, ScalarizeLoad) {
+  test_load_scalarize<GlobalLoadStmt>();
+  test_load_scalarize<LocalLoadStmt>();
+}

 } // namespace lang
38 changes: 38 additions & 0 deletions tests/python/test_matrix.py
@@ -817,3 +817,41 @@ def foo() -> ti.types.matrix(2, 2, ti.f32):
         return a @ b.transpose()

     assert foo() == [[1.0, 2.0], [2.0, 4.0]]
+
+
+@test_utils.test(arch=[ti.cuda, ti.cpu],
+                 real_matrix=True,
+                 real_matrix_scalarize=True)
+def test_store_scalarize():
+    @ti.kernel
+    def func(a: ti.types.ndarray()):
+        for i in range(5):
+            a[i] = [[i, i + 1], [i + 2, i + 3]]
+
+    x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
+    func(x)
+
+    assert (x[0] == [[0, 1], [2, 3]])
+    assert (x[1] == [[1, 2], [3, 4]])
+    assert (x[2] == [[2, 3], [4, 5]])
+    assert (x[3] == [[3, 4], [5, 6]])
+    assert (x[4] == [[4, 5], [6, 7]])
+
+
+@test_utils.test(arch=[ti.cuda, ti.cpu],
+                 real_matrix=True,
+                 real_matrix_scalarize=True)
+def test_load_store_scalarize():
+    @ti.kernel
+    def func(a: ti.types.ndarray()):
+        for i in range(3):
+            a[i] = [[i, i + 1], [i + 2, i + 3]]
+
+        a[3] = a[1]
+        a[4] = a[2]
+
+    x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
+    func(x)
+
+    assert (x[3] == [[1, 2], [3, 4]])
+    assert (x[4] == [[2, 3], [4, 5]])
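Note that test_load_store_scalarize chains both rewrites: a[3] = a[1] first becomes four scalar loads gathered by a MatrixInitStmt, and store scalarization then unpacks that MatrixInitStmt into four scalar stores. To eyeball the lowered kernel while debugging, the long-standing print_ir switch can be combined with the new flags (a usage note, not part of this PR):

    import taichi as ti

    # Dumps the IR after each compiler pass; the scalarized per-element
    # loads and stores show up in the log.
    ti.init(arch=ti.cpu, real_matrix=True, real_matrix_scalarize=True,
            print_ir=True)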
