
[Lang] MatrixNdarray refactor part6: Add scalarization for LocalLoadStmt & GlobalLoadStmt with TensorType #6024

Merged Sep 16, 2022

61 commits:
6739192
[Lang] MatrixNdarray refactor part0: Support direct TensorType constr…
jim19930609 Aug 24, 2022
613a48d
Fixed minor issue
jim19930609 Aug 25, 2022
ef25863
Fixed CI failures
jim19930609 Aug 25, 2022
46aa538
Minor refactor
jim19930609 Aug 25, 2022
88dfeba
Fixed minor issue
jim19930609 Aug 25, 2022
1267793
[Lang] MatrixNdarray refactor part1: Refactored Taichi kernel argumen…
jim19930609 Aug 25, 2022
cb1c463
Fixed CI failure with Metal backend
jim19930609 Aug 26, 2022
8817413
Addressed review comments
jim19930609 Aug 26, 2022
2d40772
Merge branch 'matrix_ndarray_pr1' of github.com:jim19930609/taichi in…
jim19930609 Aug 26, 2022
834cf67
Fixed format issue with clang-tidy
jim19930609 Aug 26, 2022
31235c5
Review comments
jim19930609 Aug 26, 2022
a4aa0f3
Merge branch 'matrix_ndarray_pr1' of github.com:jim19930609/taichi in…
jim19930609 Aug 26, 2022
a9fea93
[Lang] MatrixNdarray refactor part2: Remove redundant members in pyth…
jim19930609 Aug 26, 2022
010dc7e
Merge branch 'matrix_ndarray_pr2' of github.com:jim19930609/taichi in…
jim19930609 Aug 26, 2022
126676f
Fixed CI failure
jim19930609 Aug 26, 2022
ab97396
Fix CI failures
jim19930609 Aug 26, 2022
76cfae0
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Aug 26, 2022
9f8bfd0
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Aug 26, 2022
b6d8735
Merge branch 'matrix_ndarray_pr2' of github.com:jim19930609/taichi in…
jim19930609 Aug 26, 2022
7f36e82
Renamed interface
jim19930609 Aug 26, 2022
a8a15ce
Merge branch 'matrix_ndarray_pr2' of github.com:jim19930609/taichi in…
jim19930609 Aug 26, 2022
1bbe024
Minor bug fix
jim19930609 Aug 26, 2022
87f51da
Minor bug fix
jim19930609 Aug 26, 2022
ddd4544
Merge branch 'matrix_ndarray_pr3' of github.com:jim19930609/taichi in…
jim19930609 Aug 26, 2022
73d3fce
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Aug 29, 2022
e56c727
[Lang] MatrixNdarray refactor part3: Enable TensorType for MatrixNdar…
jim19930609 Aug 29, 2022
235d889
Merge branch 'matrix_ndarray_pr3' of github.com:jim19930609/taichi in…
jim19930609 Aug 29, 2022
2fed7a8
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Aug 30, 2022
d42df6c
[Lang] MatrixNdarray refactor part4: Lowered TensorType to CHI IR lev…
jim19930609 Sep 1, 2022
3fe45de
Adjust code generation logic
jim19930609 Sep 1, 2022
f4cb4f3
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Sep 1, 2022
ecf213a
Bug fix
jim19930609 Sep 1, 2022
4856194
[Lang] MatrixNdarray refactor part5: Add scalarization for LocalStore…
jim19930609 Sep 1, 2022
31c04ac
Debuging LLVM15 on windows
jim19930609 Sep 2, 2022
98c9b4d
Fixed LLVM15 issues
jim19930609 Sep 2, 2022
a006b5e
Minor fix
jim19930609 Sep 2, 2022
18cecc2
Merge branch 'matrix_ndarray_pr5' of github.com:jim19930609/taichi in…
jim19930609 Sep 2, 2022
2eed5a0
Fixed compilation issues
jim19930609 Sep 2, 2022
654c169
Merge branch 'matrix_ndarray_pr5' of github.com:jim19930609/taichi in…
jim19930609 Sep 2, 2022
81d298e
Add unit tests
jim19930609 Sep 2, 2022
89c7e07
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Sep 2, 2022
22ec49a
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Sep 6, 2022
829c5fb
Bug fix
jim19930609 Sep 6, 2022
d3b9232
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Sep 8, 2022
1774e9c
Adjust type promotion logics
jim19930609 Sep 9, 2022
6209de2
[Lang] MatrixNdarray refactor part6: Add scalarization for LocalLoadS…
jim19930609 Sep 9, 2022
262dbf0
Bug fix
jim19930609 Sep 9, 2022
858e17d
Bug fix
jim19930609 Sep 9, 2022
f9a7bda
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Sep 9, 2022
b66ab75
Merge branch 'matrix_ndarray_pr6' of github.com:jim19930609/taichi in…
jim19930609 Sep 9, 2022
1d93df4
Turn-off scalarization by default
jim19930609 Sep 10, 2022
42197f7
Merge branch 'matrix_ndarray_pr6' of github.com:jim19930609/taichi in…
jim19930609 Sep 10, 2022
2dd2a26
Bug fix
jim19930609 Sep 10, 2022
c8aa4d9
Minor fix
jim19930609 Sep 11, 2022
195987b
Merge branch 'matrix_ndarray_pr6' of github.com:jim19930609/taichi in…
jim19930609 Sep 11, 2022
2aa707d
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Sep 15, 2022
022f072
Add python test for scalarization
jim19930609 Sep 15, 2022
5513ffa
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Sep 15, 2022
e906487
Bug fix
jim19930609 Sep 16, 2022
b176726
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Sep 16, 2022
5b0bbfe
Bug fix
jim19930609 Sep 16, 2022
9 changes: 8 additions & 1 deletion python/taichi/lang/ast/ast_transformer.py
@@ -15,8 +15,9 @@
from taichi.lang.ast.symbol_resolver import ASTResolver
from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
from taichi.lang.field import Field
from taichi.lang.impl import current_cfg
from taichi.lang.matrix import (Matrix, MatrixType, Vector, _PyScopeMatrixImpl,
                                _TiScopeMatrixImpl)
                                _TiScopeMatrixImpl, make_matrix)
from taichi.lang.snode import append
from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type
from taichi.types import (annotations, ndarray_type, primitive_types,
@@ -114,6 +115,12 @@ def build_Assign(ctx, node):
    @staticmethod
    def build_assign_slice(ctx, node_target, values, is_static_assign):
        target = ASTTransformer.build_Subscript(ctx, node_target, get_ref=True)
        if current_cfg().real_matrix:
            if isinstance(node_target.value.ptr,
                          any_array.AnyArray) and isinstance(
                              values, (list, tuple)):
                values = make_matrix(values)

        if isinstance(node_target.value.ptr, Matrix):
            if isinstance(node_target.value.ptr._impl, _TiScopeMatrixImpl):
                target._assign(values)
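For context, the hunk above makes a list or tuple assigned into an ndarray slice go through make_matrix(), so the right-hand side becomes a single TensorType value. A minimal sketch of the pattern this enables, assuming a CPU build with the real_matrix flag (the kernel below is illustrative, not part of this diff):

import taichi as ti

ti.init(arch=ti.cpu, real_matrix=True)

x = ti.Matrix.ndarray(2, 2, ti.i32, shape=4)

@ti.kernel
def fill(a: ti.types.ndarray()):
    for i in range(4):
        # The RHS list reaches build_assign_slice() above and is
        # wrapped into one TensorType value via make_matrix()
        a[i] = [[i, i + 1], [i + 2, i + 3]]

fill(x)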
2 changes: 1 addition & 1 deletion python/taichi/lang/impl.py
@@ -16,7 +16,7 @@
from taichi.lang.kernel_arguments import SparseMatrixProxy
from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType,
                                Vector, _IntermediateMatrix,
                                _MatrixFieldElement, make_matrix)
                                _MatrixFieldElement)
from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance,
                              MeshRelationAccessProxy,
                              MeshReorderedMatrixFieldProxy,
1 change: 1 addition & 0 deletions taichi/analysis/offline_cache_util.cpp
@@ -66,6 +66,7 @@ static std::vector<std::uint8_t> get_offline_cache_key_of_compile_config(
  serializer(config->experimental_auto_mesh_local);
  serializer(config->auto_mesh_local_default_occupacy);
  serializer(config->real_matrix);
  serializer(config->real_matrix_scalarize);
  serializer.finalize();

  return serializer.data;
2 changes: 2 additions & 0 deletions taichi/python/export_lang.cpp
@@ -203,6 +203,8 @@ void export_lang(py::module &m) {
                     &CompileConfig::ndarray_use_cached_allocator)
      .def_readwrite("use_mesh", &CompileConfig::use_mesh)
      .def_readwrite("real_matrix", &CompileConfig::real_matrix)
      .def_readwrite("real_matrix_scalarize",
                     &CompileConfig::real_matrix_scalarize)
      .def_readwrite("cc_compile_cmd", &CompileConfig::cc_compile_cmd)
      .def_readwrite("cc_link_cmd", &CompileConfig::cc_link_cmd)
      .def_readwrite("quant_opt_store_fusion",
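Together with the cache-key change in offline_cache_util.cpp above, this binding makes the pass user-controllable from Python, and kernels compiled with different scalarization settings hash to different offline-cache entries. A hedged usage sketch (flag names as defined in this diff; init() forwarding these keyword arguments to CompileConfig is assumed from Taichi's usual behavior):

import taichi as ti

# Scalarization is off by default in this PR; opt in explicitly.
ti.init(arch=ti.cpu, real_matrix=True, real_matrix_scalarize=True)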
79 changes: 74 additions & 5 deletions taichi/transforms/scalarize.cpp
@@ -8,10 +8,14 @@ TLANG_NAMESPACE_BEGIN

class Scalarize : public IRVisitor {
 public:
  DelayedIRModifier modifier_;

  Scalarize(IRNode *node) {
    allow_undefined_visitor = true;
    invoke_default_visitor = false;
    node->accept(this);

    modifier_.modify_ir();
  }

  /*
@@ -51,18 +55,75 @@
      int num_elements = val_tensor_type->get_num_elements();
      for (int i = 0; i < num_elements; i++) {
        auto const_stmt = std::make_unique<ConstStmt>(
            TypedConstant(stmt->val->ret_type.get_element_type(), i));
            TypedConstant(get_data_type<int32>(), i));

        auto ptr_offset_stmt =
            std::make_unique<PtrOffsetStmt>(stmt->dest, const_stmt.get());
        auto scalarized_stmt = std::make_unique<T>(ptr_offset_stmt.get(),
                                                   matrix_init_stmt->values[i]);

        stmt->insert_before_me(std::move(const_stmt));
        stmt->insert_before_me(std::move(ptr_offset_stmt));
        stmt->insert_before_me(std::move(scalarized_stmt));
        modifier_.insert_before(stmt, std::move(const_stmt));
        modifier_.insert_before(stmt, std::move(ptr_offset_stmt));
        modifier_.insert_before(stmt, std::move(scalarized_stmt));
      }
      modifier_.erase(stmt);
    }
  }

  /*

    Before:
      TensorType<4 x i32> val = LoadStmt(TensorType<4 x i32>* src)

    After:
      i32* addr0 = PtrOffsetStmt(TensorType<4 x i32>* src, 0)
      i32* addr1 = PtrOffsetStmt(TensorType<4 x i32>* src, 1)
      i32* addr2 = PtrOffsetStmt(TensorType<4 x i32>* src, 2)
      i32* addr3 = PtrOffsetStmt(TensorType<4 x i32>* src, 3)

      i32 val0 = LoadStmt(addr0)
      i32 val1 = LoadStmt(addr1)
      i32 val2 = LoadStmt(addr2)
      i32 val3 = LoadStmt(addr3)

      tmp = MatrixInitStmt(val0, val1, val2, val3)

      stmt->replace_all_usages_with(tmp)
  */
  template <typename T>
  void scalarize_load_stmt(T *stmt) {
    auto src_dtype = stmt->src->ret_type.ptr_removed();
    if (src_dtype->template is<TensorType>()) {
      // Needs scalarize
      auto src_tensor_type = src_dtype->template as<TensorType>();

      std::vector<Stmt *> matrix_init_values;
      int num_elements = src_tensor_type->get_num_elements();

      for (int i = 0; i < num_elements; i++) {
        auto const_stmt = std::make_unique<ConstStmt>(
            TypedConstant(get_data_type<int32>(), i));

        auto ptr_offset_stmt =
            std::make_unique<PtrOffsetStmt>(stmt->src, const_stmt.get());
        auto scalarized_stmt = std::make_unique<T>(ptr_offset_stmt.get());

        matrix_init_values.push_back(scalarized_stmt.get());

        modifier_.insert_before(stmt, std::move(const_stmt));
        modifier_.insert_before(stmt, std::move(ptr_offset_stmt));
        modifier_.insert_before(stmt, std::move(scalarized_stmt));
      }

      auto matrix_init_stmt =
          std::make_unique<MatrixInitStmt>(matrix_init_values);
      matrix_init_stmt->ret_type = src_dtype;

      stmt->replace_usages_with(matrix_init_stmt.get());
      modifier_.insert_before(stmt, std::move(matrix_init_stmt));

      modifier_.erase(stmt);
    }
  }

Expand Down Expand Up @@ -107,6 +168,14 @@ class Scalarize : public IRVisitor {
void visit(LocalStoreStmt *stmt) override {
scalarize_store_stmt<LocalStoreStmt>(stmt);
}

void visit(GlobalLoadStmt *stmt) override {
scalarize_load_stmt<GlobalLoadStmt>(stmt);
}

void visit(LocalLoadStmt *stmt) override {
scalarize_load_stmt<LocalLoadStmt>(stmt);
}
};

namespace irpass {
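Two notes on the pass itself. First, all insertions and erasures now go through a DelayedIRModifier and are applied in one batch by modifier_.modify_ir() in the constructor, so statements are not mutated while the visitor is still walking the block. Second, at the Python level a kernel as small as the following exercises the new load path; the sketch below mirrors the tests at the end of this diff and is not code from it:

import taichi as ti

ti.init(arch=ti.cpu, real_matrix=True, real_matrix_scalarize=True)

x = ti.Matrix.ndarray(2, 2, ti.i32, shape=2)

@ti.kernel
def copy_elem(a: ti.types.ndarray()):
    # Reading a[1] emits a GlobalLoadStmt of TensorType<4 x i32>;
    # scalarize_load_stmt() rewrites it into four scalar loads feeding
    # one MatrixInitStmt, which the store path then scalarizes again.
    a[0] = a[1]

copy_elem(x)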
66 changes: 63 additions & 3 deletions tests/cpp/transforms/scalarize_test.cpp
@@ -9,7 +9,7 @@ namespace lang {

// Basic tests within a basic block
template <typename T>
void test_scalarize() {
void test_store_scalarize() {
  TestProgram test_prog;
  test_prog.setup();

@@ -76,9 +76,69 @@ void test_scalarize() {
  EXPECT_EQ(block->statements[16]->is<T>(), true);
}

template <typename T>
void test_load_scalarize() {
  TestProgram test_prog;
  test_prog.setup();

  auto block = std::make_unique<Block>();

  auto func = []() {};
  auto kernel =
      std::make_unique<Kernel>(*test_prog.prog(), func, "fake_kernel");
  block->kernel = kernel.get();

  auto &type_factory = TypeFactory::get_instance();

  /*
    TensorType<4 x i32>* %1 = ExternalPtrStmt()
    TensorType<4 x i32>  %2 = LoadStmt(%1)
  */
  Type *tensor_type = type_factory.get_tensor_type(
      {2, 2}, type_factory.get_primitive_type(PrimitiveTypeID::i32));
  auto argload_stmt = block->push_back<ArgLoadStmt>(0 /*arg_id*/, tensor_type);

  std::vector<Stmt *> indices = {};
  Stmt *src_stmt = block->push_back<ExternalPtrStmt>(
      argload_stmt, indices);  // fake ExternalPtrStmt
  src_stmt->ret_type = type_factory.get_pointer_type(tensor_type);

  block->push_back<T>(src_stmt);

  irpass::scalarize(block.get());

  EXPECT_EQ(block->size(), 1 /*argload*/ + 1 /*external_ptr*/ + 4 /*const*/ +
                               4 /*ptroffset*/ + 4 /*load*/ +
                               1 /*matrix_init*/);

  // Check for scalarized statements
  EXPECT_EQ(block->statements[2]->is<ConstStmt>(), true);
  EXPECT_EQ(block->statements[3]->is<PtrOffsetStmt>(), true);
  EXPECT_EQ(block->statements[4]->is<T>(), true);

  EXPECT_EQ(block->statements[5]->is<ConstStmt>(), true);
  EXPECT_EQ(block->statements[6]->is<PtrOffsetStmt>(), true);
  EXPECT_EQ(block->statements[7]->is<T>(), true);

  EXPECT_EQ(block->statements[8]->is<ConstStmt>(), true);
  EXPECT_EQ(block->statements[9]->is<PtrOffsetStmt>(), true);
  EXPECT_EQ(block->statements[10]->is<T>(), true);

  EXPECT_EQ(block->statements[11]->is<ConstStmt>(), true);
  EXPECT_EQ(block->statements[12]->is<PtrOffsetStmt>(), true);
  EXPECT_EQ(block->statements[13]->is<T>(), true);

  EXPECT_EQ(block->statements[14]->is<MatrixInitStmt>(), true);
}

TEST(Scalarize, ScalarizeStore) {
  test_scalarize<GlobalStoreStmt>();
  test_scalarize<LocalStoreStmt>();
  test_store_scalarize<GlobalStoreStmt>();
  test_store_scalarize<LocalStoreStmt>();
}

TEST(Scalarize, ScalarizeLoad) {
  test_load_scalarize<GlobalLoadStmt>();
  test_load_scalarize<LocalLoadStmt>();
}

} // namespace lang
38 changes: 38 additions & 0 deletions tests/python/test_matrix.py
@@ -817,3 +817,41 @@ def foo() -> ti.types.matrix(2, 2, ti.f32):
        return a @ b.transpose()

    assert foo() == [[1.0, 2.0], [2.0, 4.0]]


@test_utils.test(arch=[ti.cuda, ti.cpu],
                 real_matrix=True,
                 real_matrix_scalarize=True)
def test_store_scalarize():
    @ti.kernel
    def func(a: ti.types.ndarray()):
        for i in range(5):
            a[i] = [[i, i + 1], [i + 2, i + 3]]

    x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
    func(x)

    assert (x[0] == [[0, 1], [2, 3]])
    assert (x[1] == [[1, 2], [3, 4]])
    assert (x[2] == [[2, 3], [4, 5]])
    assert (x[3] == [[3, 4], [5, 6]])
    assert (x[4] == [[4, 5], [6, 7]])


@test_utils.test(arch=[ti.cuda, ti.cpu],
                 real_matrix=True,
                 real_matrix_scalarize=True)
def test_load_store_scalarize():
    @ti.kernel
    def func(a: ti.types.ndarray()):
        for i in range(3):
            a[i] = [[i, i + 1], [i + 2, i + 3]]

        a[3] = a[1]
        a[4] = a[2]

    x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
    func(x)

    assert (x[3] == [[1, 2], [3, 4]])
    assert (x[4] == [[2, 3], [4, 5]])