
[Lang] MatrixNdarray refactor part5: Add scalarization for LocalStoreStmt & GlobalStoreStmt with TensorType #5946

Merged · 49 commits · Sep 15, 2022

Changes from 46 commits

Commits (49)
6739192
[Lang] MatrixNdarray refactor part0: Support direct TensorType constr…
jim19930609 Aug 24, 2022
613a48d
Fixed minor issue
jim19930609 Aug 25, 2022
ef25863
Fixed CI failures
jim19930609 Aug 25, 2022
46aa538
Minor refactor
jim19930609 Aug 25, 2022
88dfeba
Fixed minor issue
jim19930609 Aug 25, 2022
1267793
[Lang] MatrixNdarray refactor part1: Refactored Taichi kernel argumen…
jim19930609 Aug 25, 2022
cb1c463
Fixed CI failure with Metal backend
jim19930609 Aug 26, 2022
8817413
Addressed review comments
jim19930609 Aug 26, 2022
2d40772
Merge branch 'matrix_ndarray_pr1' of github.com:jim19930609/taichi in…
jim19930609 Aug 26, 2022
834cf67
Fixed format issue with clang-tidy
jim19930609 Aug 26, 2022
31235c5
Review comments
jim19930609 Aug 26, 2022
a4aa0f3
Merge branch 'matrix_ndarray_pr1' of github.com:jim19930609/taichi in…
jim19930609 Aug 26, 2022
a9fea93
[Lang] MatrixNdarray refactor part2: Remove redundant members in pyth…
jim19930609 Aug 26, 2022
010dc7e
Merge branch 'matrix_ndarray_pr2' of github.com:jim19930609/taichi in…
jim19930609 Aug 26, 2022
126676f
Fixed CI failure
jim19930609 Aug 26, 2022
ab97396
Fix CI failures
jim19930609 Aug 26, 2022
76cfae0
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Aug 26, 2022
9f8bfd0
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Aug 26, 2022
b6d8735
Merge branch 'matrix_ndarray_pr2' of github.com:jim19930609/taichi in…
jim19930609 Aug 26, 2022
7f36e82
Renamed interface
jim19930609 Aug 26, 2022
a8a15ce
Merge branch 'matrix_ndarray_pr2' of github.com:jim19930609/taichi in…
jim19930609 Aug 26, 2022
1bbe024
Minor bug fix
jim19930609 Aug 26, 2022
87f51da
Minor bug fix
jim19930609 Aug 26, 2022
ddd4544
Merge branch 'matrix_ndarray_pr3' of github.com:jim19930609/taichi in…
jim19930609 Aug 26, 2022
73d3fce
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Aug 29, 2022
e56c727
[Lang] MatrixNdarray refactor part3: Enable TensorType for MatrixNdar…
jim19930609 Aug 29, 2022
235d889
Merge branch 'matrix_ndarray_pr3' of github.com:jim19930609/taichi in…
jim19930609 Aug 29, 2022
2fed7a8
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Aug 30, 2022
d42df6c
[Lang] MatrixNdarray refactor part4: Lowered TensorType to CHI IR lev…
jim19930609 Sep 1, 2022
3fe45de
Adjust code generation logic
jim19930609 Sep 1, 2022
f4cb4f3
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Sep 1, 2022
ecf213a
Bug fix
jim19930609 Sep 1, 2022
4856194
[Lang] MatrixNdarray refactor part5: Add scalarization for LocalStore…
jim19930609 Sep 1, 2022
31c04ac
Debuging LLVM15 on windows
jim19930609 Sep 2, 2022
98c9b4d
Fixed LLVM15 issues
jim19930609 Sep 2, 2022
a006b5e
Minor fix
jim19930609 Sep 2, 2022
18cecc2
Merge branch 'matrix_ndarray_pr5' of github.com:jim19930609/taichi in…
jim19930609 Sep 2, 2022
2eed5a0
Fixed compilation issues
jim19930609 Sep 2, 2022
654c169
Merge branch 'matrix_ndarray_pr5' of github.com:jim19930609/taichi in…
jim19930609 Sep 2, 2022
81d298e
Add unit tests
jim19930609 Sep 2, 2022
89c7e07
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Sep 2, 2022
22ec49a
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Sep 6, 2022
829c5fb
Bug fix
jim19930609 Sep 6, 2022
d3b9232
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Sep 8, 2022
1774e9c
Adjust type promotion logics
jim19930609 Sep 9, 2022
f9a7bda
Merge branch 'master' of github.com:taichi-dev/taichi into matrix_nda…
jim19930609 Sep 9, 2022
1d93df4
Turn-off scalarization by default
jim19930609 Sep 10, 2022
c8aa4d9
Minor fix
jim19930609 Sep 11, 2022
add09e6
Addressed review comments
jim19930609 Sep 14, 2022
1 change: 0 additions & 1 deletion python/taichi/lang/impl.py
@@ -236,7 +236,6 @@ def subscript(value, *_indices, skip_reordered=False, get_ref=False):
assert current_cfg().real_matrix is True
assert is_tensor(value.ptr.get_ret_type())

- # TODO(zhanlue): Merge _ti_core.subscript and _ti_core.make_index_expr
return Expr(
_ti_core.subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))
Expand Down
6 changes: 4 additions & 2 deletions taichi/ir/frontend_ir.cpp
@@ -180,8 +180,10 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
binary_op_type_symbol(type), lhs->ret_type->to_string(),
rhs->ret_type->to_string()));
};
- if (!lhs_type->is<PrimitiveType>() || !rhs_type->is<PrimitiveType>())
-   error();
+ if ((lhs_type->is<PrimitiveType>() && rhs_type->is<TensorType>()) ||
+     (lhs_type->is<TensorType>() && rhs_type->is<PrimitiveType>())) {
+   TI_NOT_IMPLEMENTED;
+ }
if (binary_is_bitwise(type) &&
(!is_integral(lhs_type) || !is_integral(rhs_type)))
error();
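Net effect of this hunk: where the old check rejected any non-primitive operand outright, the rewritten condition lets tensor-tensor binary operations fall through to the regular type-promotion path, and only the mixed primitive/tensor operand cases abort via TI_NOT_IMPLEMENTED.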
15 changes: 10 additions & 5 deletions taichi/ir/statements.cpp
@@ -65,8 +65,10 @@ PtrOffsetStmt::PtrOffsetStmt(Stmt *origin_input, Stmt *offset_input) {
origin = origin_input;
offset = offset_input;
if (origin->is<AllocaStmt>()) {
- TI_ASSERT(origin->cast<AllocaStmt>()->ret_type->is<TensorType>());
- auto tensor_type = origin->cast<AllocaStmt>()->ret_type->cast<TensorType>();
+ TI_ASSERT(
+     origin->cast<AllocaStmt>()->ret_type.ptr_removed()->is<TensorType>());
+ auto tensor_type =
+     origin->cast<AllocaStmt>()->ret_type.ptr_removed()->cast<TensorType>();
element_type() = tensor_type->get_element_type();
element_type().set_is_pointer(true);
} else if (origin->is<GlobalTemporaryStmt>()) {
@@ -78,9 +80,12 @@ PtrOffsetStmt::PtrOffsetStmt(Stmt *origin_input, Stmt *offset_input) {
} else if (origin->is<GlobalPtrStmt>()) {
element_type() = origin->cast<GlobalPtrStmt>()->ret_type;
} else if (origin->is<ExternalPtrStmt>()) {
- TI_ASSERT(origin->cast<ExternalPtrStmt>()->ret_type->is<TensorType>());
- auto tensor_type =
-     origin->cast<ExternalPtrStmt>()->ret_type->cast<TensorType>();
+ TI_ASSERT(origin->cast<ExternalPtrStmt>()
+               ->ret_type.ptr_removed()
+               ->is<TensorType>());
+ auto tensor_type = origin->cast<ExternalPtrStmt>()
+                        ->ret_type.ptr_removed()
+                        ->cast<TensorType>();
element_type() = tensor_type->get_element_type();
element_type().set_is_pointer(true);
} else {
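Both branches now strip the pointer qualifier via ret_type.ptr_removed() before the TensorType check, so the assertions also hold when the AllocaStmt or ExternalPtrStmt origin carries a pointer-to-TensorType rather than a bare TensorType.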
1 change: 1 addition & 0 deletions taichi/ir/transforms.h
@@ -29,6 +29,7 @@ namespace irpass {

void re_id(IRNode *root);
void flag_access(IRNode *root);
+ void scalarize(IRNode *root);
bool die(IRNode *root);
bool simplify(IRNode *root, const CompileConfig &config);
bool cfg_optimization(
1 change: 1 addition & 0 deletions taichi/program/compile_config.cpp
@@ -49,6 +49,7 @@ CompileConfig::CompileConfig() {
ndarray_use_cached_allocator = true;
use_mesh = false;
real_matrix = false;
+ real_matrix_scalarize = true;

saturating_grid_dim = 0;
max_block_dim = 0;
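Note that this view shows changes from 46 of the 49 commits; the later commit 1d93df4 ("Turn-off scalarization by default") presumably flips this default, which is why the snapshot above still reads real_matrix_scalarize = true.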
1 change: 1 addition & 0 deletions taichi/program/compile_config.h
@@ -45,6 +45,7 @@ struct CompileConfig {
bool ndarray_use_cached_allocator;
bool use_mesh;
bool real_matrix;
+ bool real_matrix_scalarize;
DataType default_fp;
DataType default_ip;
DataType default_up;
8 changes: 8 additions & 0 deletions taichi/transforms/compile_to_offloads.cpp
@@ -52,6 +52,9 @@ void compile_to_offloads(IRNode *ir,
print("Lowered");
}

+ irpass::scalarize(ir);
+ print("Scalarized");

irpass::type_check(ir, config);
print("Typechecked");
irpass::analysis::verify(ir);
@@ -256,6 +259,7 @@ void offload_to_executable(IRNode *ir,
irpass::full_simplify(
ir, config,
{lower_global_access, /*autodiff_enabled*/ false, kernel->program});

print("Simplified IV");

if (determine_ad_stack_size) {
@@ -316,6 +320,10 @@ void compile_function(IRNode *ir,
irpass::lower_ast(ir);
print("Lowered");
}

+ irpass::scalarize(ir);
+ print("Scalarized");

irpass::lower_access(ir, config, {{}, true});
print("Access lowered");
irpass::analysis::verify(ir);
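The pipeline above calls the new pass unconditionally right after AST lowering. Given the real_matrix_scalarize flag added to CompileConfig and the later "Turn-off scalarization by default" commit, a guarded call is the likely end state. A minimal sketch of such gating — the guard itself is an assumption, not part of this diff:

// Hedged sketch: gate the pass on the new config flags (assumed wiring).
if (config.real_matrix && config.real_matrix_scalarize) {
  irpass::scalarize(ir);  // expand TensorType stores into scalar stores
  print("Scalarized");
}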
182 changes: 182 additions & 0 deletions taichi/transforms/scalarize.cpp
@@ -0,0 +1,182 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/system/profiler.h"

TLANG_NAMESPACE_BEGIN
class MatrixInitRemoval : public IRVisitor {
public:
MatrixInitRemoval(IRNode *node, std::unordered_set<Stmt *> &&remove_list) {
allow_undefined_visitor = true;
invoke_default_visitor = false;
remove_list_ = std::move(remove_list);
node->accept(this);
}

void visit(Block *stmt_list) override {
for (auto &stmt : stmt_list->statements) {
stmt->accept(this);
}
}

void visit(IfStmt *if_stmt) override {
if (if_stmt->true_statements)
if_stmt->true_statements->accept(this);
if (if_stmt->false_statements) {
if_stmt->false_statements->accept(this);
}
}

void visit(WhileStmt *stmt) override {
stmt->body->accept(this);
}

void visit(RangeForStmt *for_stmt) override {
for_stmt->body->accept(this);
}

void visit(StructForStmt *for_stmt) override {
for_stmt->body->accept(this);
}

void visit(MeshForStmt *for_stmt) override {
for_stmt->body->accept(this);
}

void visit(OffloadedStmt *stmt) override {
stmt->all_blocks_accept(this);
}

void visit(MatrixInitStmt *stmt) override {
if (remove_list_.count(stmt)) {
stmt->parent->erase(stmt);
}
}

private:
std::unordered_set<Stmt *> remove_list_;
};

class Scalarize : public IRVisitor {
public:
Scalarize(IRNode *node) {
allow_undefined_visitor = true;
invoke_default_visitor = false;
node->accept(this);
}

/*
"val" of StoreStmt should have already been replaced by a MatrixInitStmt in
former scalarization.

Before:
StoreStmt(TensorType<4 x i32>* dest, TensorType<4 x i32> val)

After:
addr0 = PtrOffsetStmt(TensorType<4 x i32>* dest, 0)
addr1 = PtrOffsetStmt(TensorType<4 x i32>* dest, 1)
addr2 = PtrOffsetStmt(TensorType<4 x i32>* dest, 2)
addr3 = PtrOffsetStmt(TensorType<4 x i32>* dest, 3)

StoreStmt(i32* addr0, i32 val->cast<MatrixInitStmt>()->val[0])
StoreStmt(i32* addr1, i32 val->cast<MatrixInitStmt>()->val[1])
StoreStmt(i32* addr2, i32 val->cast<MatrixInitStmt>()->val[2])
StoreStmt(i32* addr3, i32 val->cast<MatrixInitStmt>()->val[3])
*/
template <typename T>
void scalarize_store_stmt(T *stmt) {
auto dest_dtype = stmt->dest->ret_type.ptr_removed();
auto val_dtype = stmt->val->ret_type;
if (dest_dtype->template is<TensorType>() &&
val_dtype->template is<TensorType>()) {
// Needs scalarize
auto dest_tensor_type = dest_dtype->template as<TensorType>();
auto val_tensor_type = val_dtype->template as<TensorType>();
TI_ASSERT(dest_tensor_type->get_shape() == val_tensor_type->get_shape());
TI_ASSERT(dest_tensor_type->get_element_type() ==
val_tensor_type->get_element_type());

TI_ASSERT(stmt->val->template is<MatrixInitStmt>());
auto matrix_init_stmt = stmt->val->template as<MatrixInitStmt>();

int num_elements = val_tensor_type->get_num_elements();
for (int i = 0; i < num_elements; i++) {
auto const_stmt = std::make_unique<ConstStmt>(
TypedConstant(stmt->val->ret_type.get_element_type(), i));

auto ptr_offset_stmt =
std::make_unique<PtrOffsetStmt>(stmt->dest, const_stmt.get());
auto scalarized_stmt = std::make_unique<T>(ptr_offset_stmt.get(),
matrix_init_stmt->values[i]);

stmt->insert_before_me(std::move(const_stmt));
stmt->insert_before_me(std::move(ptr_offset_stmt));
stmt->insert_before_me(std::move(scalarized_stmt));
}
stmt->parent->erase(stmt);
}
}

void visit(Block *stmt_list) override {
for (auto &stmt : stmt_list->statements) {
stmt->accept(this);
}
}

void visit(IfStmt *if_stmt) override {
if (if_stmt->true_statements)
if_stmt->true_statements->accept(this);
if (if_stmt->false_statements) {
if_stmt->false_statements->accept(this);
}
}

void visit(WhileStmt *stmt) override {
stmt->body->accept(this);
}

void visit(RangeForStmt *for_stmt) override {
for_stmt->body->accept(this);
}

void visit(StructForStmt *for_stmt) override {
for_stmt->body->accept(this);
}

void visit(MeshForStmt *for_stmt) override {
for_stmt->body->accept(this);
}

void visit(OffloadedStmt *stmt) override {
stmt->all_blocks_accept(this);
}

void visit(GlobalStoreStmt *stmt) override {
scalarize_store_stmt<GlobalStoreStmt>(stmt);
}

void visit(LocalStoreStmt *stmt) override {
scalarize_store_stmt<LocalStoreStmt>(stmt);
}

// MatrixInitStmts collected for the follow-up MatrixInitRemoval pass
// (consumed by irpass::scalarize below).
std::unordered_set<Stmt *> matrix_init_to_remove_;
};

namespace irpass {

void scalarize(IRNode *root) {
TI_AUTO_PROF;
Scalarize scalarize_pass(root);

/*
The Scalarize pass generates temporary MatrixInitStmts, which are only used
as rvalues. Remove these MatrixInitStmts since they are no longer needed.
*/
MatrixInitRemoval matrix_init_removal_pass(
root, std::move(scalarize_pass.matrix_init_to_remove_));
}

} // namespace irpass

TLANG_NAMESPACE_END
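For orientation, the pass ordering this diff establishes in compile_to_offloads.cpp, condensed into a sketch (the calls are taken from the hunks above; the surrounding control flow is elided):

irpass::lower_ast(ir);           // AST -> CHI IR (when lowering is needed)
irpass::scalarize(ir);           // TensorType stores -> per-element scalar stores
irpass::type_check(ir, config);  // runs on the already-scalarized IR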
85 changes: 85 additions & 0 deletions tests/cpp/transforms/scalarize_test.cpp
@@ -0,0 +1,85 @@
#include "gtest/gtest.h"

#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "tests/cpp/program/test_program.h"

namespace taichi {
namespace lang {

// Basic tests within a basic block
template <typename T>
void test_scalarize() {
TestProgram test_prog;
test_prog.setup();

auto block = std::make_unique<Block>();

auto func = []() {};
auto kernel =
std::make_unique<Kernel>(*test_prog.prog(), func, "fake_kernel");
block->kernel = kernel.get();

auto &type_factory = TypeFactory::get_instance();

/*
TensorType<4 x i32>* %1 = ExternalPtrStmt()  // or AllocaStmt for LocalStoreStmt
TensorType<4 x i32> %2 = MatrixInitStmt([1, 1, 2, 2])
StoreStmt(%1, %2)
*/
Type *tensor_type = type_factory.get_tensor_type(
{2, 2}, type_factory.get_primitive_type(PrimitiveTypeID::i32));
auto const_1_stmt = block->push_back<ConstStmt>(TypedConstant(1));
auto const_2_stmt = block->push_back<ConstStmt>(TypedConstant(2));
auto argload_stmt = block->push_back<ArgLoadStmt>(0 /*arg_id*/, tensor_type);

Stmt *dest_stmt = nullptr;
if (std::is_same<T, GlobalStoreStmt>::value) {
std::vector<Stmt *> indices = {};
dest_stmt = block->push_back<ExternalPtrStmt>(
argload_stmt, indices); // fake ExternalPtrStmt

} else {
dest_stmt = block->push_back<AllocaStmt>(tensor_type);
}
dest_stmt->ret_type = type_factory.get_pointer_type(tensor_type);

std::vector<Stmt *> matrix_init_vals = {const_1_stmt, const_1_stmt,
const_2_stmt, const_2_stmt};
auto matrix_init_stmt =
block->push_back<MatrixInitStmt>(std::move(matrix_init_vals));
matrix_init_stmt->ret_type = tensor_type;

block->push_back<T>(dest_stmt, matrix_init_stmt);

irpass::scalarize(block.get());

EXPECT_EQ(block->size(), 2 /*const*/ + 1 /*argload*/ + 1 /*external_ptr*/ +
1 /*matrix_init*/ + 4 /*const*/ +
4 /*ptroffset*/ + 4 /*store*/);

// Check for scalarized statements
EXPECT_EQ(block->statements[5]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[6]->is<PtrOffsetStmt>(), true);
EXPECT_EQ(block->statements[7]->is<T>(), true);

EXPECT_EQ(block->statements[8]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[9]->is<PtrOffsetStmt>(), true);
EXPECT_EQ(block->statements[10]->is<T>(), true);

EXPECT_EQ(block->statements[11]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[12]->is<PtrOffsetStmt>(), true);
EXPECT_EQ(block->statements[13]->is<T>(), true);

EXPECT_EQ(block->statements[14]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[15]->is<PtrOffsetStmt>(), true);
EXPECT_EQ(block->statements[16]->is<T>(), true);
}

TEST(Scalarize, ScalarizeStore) {
test_scalarize<GlobalStoreStmt>();
test_scalarize<LocalStoreStmt>();
}

} // namespace lang
} // namespace taichi