Skip to content

Commit

Permalink
Addressed review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jim19930609 committed Sep 29, 2022
1 parent 3ab0382 commit 47675b6
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 86 deletions.
53 changes: 15 additions & 38 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,7 @@

namespace taichi::lang {

class BasicVisitor : public IRVisitor {
public:
void visit(Block *stmt_list) override {
for (auto &stmt : stmt_list->statements) {
stmt->accept(this);
}
}

void visit(IfStmt *if_stmt) override {
if (if_stmt->true_statements)
if_stmt->true_statements->accept(this);
if (if_stmt->false_statements) {
if_stmt->false_statements->accept(this);
}
}

void visit(WhileStmt *stmt) override {
stmt->body->accept(this);
}

void visit(RangeForStmt *for_stmt) override {
for_stmt->body->accept(this);
}

void visit(StructForStmt *for_stmt) override {
for_stmt->body->accept(this);
}

void visit(MeshForStmt *for_stmt) override {
for_stmt->body->accept(this);
}
};

class Scalarize : public BasicVisitor {
class Scalarize : public BasicStmtVisitor {
public:
DelayedIRModifier modifier_;

Expand Down Expand Up @@ -85,13 +52,16 @@ class Scalarize : public BasicVisitor {
auto matrix_init_stmt = stmt->val->template as<MatrixInitStmt>();

int num_elements = val_tensor_type->get_num_elements();

auto primitive_type = dest_tensor_type->get_element_type();
for (int i = 0; i < num_elements; i++) {
auto const_stmt = std::make_unique<ConstStmt>(
TypedConstant(get_data_type<int32>(), i));

auto ptr_offset_stmt =
std::make_unique<MatrixPtrStmt>(stmt->dest, const_stmt.get());
ptr_offset_stmt->ret_type = primitive_type;
ptr_offset_stmt->ret_type.set_is_pointer(true);

auto scalarized_stmt = std::make_unique<T>(ptr_offset_stmt.get(),
matrix_init_stmt->values[i]);

Expand Down Expand Up @@ -132,13 +102,18 @@ class Scalarize : public BasicVisitor {
std::vector<Stmt *> matrix_init_values;
int num_elements = src_tensor_type->get_num_elements();

auto primitive_type = src_tensor_type->get_element_type();
for (size_t i = 0; i < num_elements; i++) {
auto const_stmt = std::make_unique<ConstStmt>(
TypedConstant(get_data_type<int32>(), i));

auto ptr_offset_stmt =
std::make_unique<MatrixPtrStmt>(stmt->src, const_stmt.get());
ptr_offset_stmt->ret_type = primitive_type;
ptr_offset_stmt->ret_type.set_is_pointer(true);

auto scalarized_stmt = std::make_unique<T>(ptr_offset_stmt.get());
scalarized_stmt->ret_type = primitive_type;

matrix_init_values.push_back(scalarized_stmt.get());

Expand Down Expand Up @@ -263,11 +238,13 @@ class Scalarize : public BasicVisitor {
TI_ASSERT(rhs_vals.size() == lhs_vals.size());

size_t num_elements = lhs_vals.size();
auto primitive_type = stmt->ret_type.get_element_type();
std::vector<Stmt *> matrix_init_values;
for (size_t i = 0; i < num_elements; i++) {
auto binary_stmt = std::make_unique<BinaryOpStmt>(
stmt->op_type, lhs_vals[i], rhs_vals[i]);
matrix_init_values.push_back(binary_stmt.get());
binary_stmt->ret_type = primitive_type;

modifier_.insert_before(stmt, std::move(binary_stmt));
}
Expand Down Expand Up @@ -300,10 +277,10 @@ class Scalarize : public BasicVisitor {
}

private:
using BasicVisitor::visit;
using BasicStmtVisitor::visit;
};

class ScalarizePointers : public BasicVisitor {
class ScalarizePointers : public BasicStmtVisitor {
public:
DelayedIRModifier modifier_;

Expand Down Expand Up @@ -400,7 +377,7 @@ class ScalarizePointers : public BasicVisitor {
}

private:
using BasicVisitor::visit;
using BasicStmtVisitor::visit;
};

namespace irpass {
Expand Down
90 changes: 42 additions & 48 deletions tests/cpp/transforms/scalarize_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

namespace taichi::lang {

template <typename T>
void test_load_scalarize() {
TEST(Scalarize, ScalarizeGlobalStore) {
// Basic tests within a basic block
TestProgram test_prog;
test_prog.setup();

Expand All @@ -22,48 +22,55 @@ void test_load_scalarize() {

/*
TensorType<4 x i32>* %1 = ExternalPtrStmt()
TensorType<4 x i32> %2 = LoadStmt(%1)
TensorType<4 x i32> %2 = MatrixInitStmt([1, 1, 2, 2])
StoreStmt(%1, %2)
*/
Type *tensor_type = type_factory.get_tensor_type(
{2, 2}, type_factory.get_primitive_type(PrimitiveTypeID::i32));
auto const_1_stmt = block->push_back<ConstStmt>(TypedConstant(1));
auto const_2_stmt = block->push_back<ConstStmt>(TypedConstant(2));
auto argload_stmt = block->push_back<ArgLoadStmt>(0 /*arg_id*/, tensor_type);

std::vector<Stmt *> indices = {};
Stmt *src_stmt = block->push_back<ExternalPtrStmt>(
Stmt *dest_stmt = block->push_back<ExternalPtrStmt>(
argload_stmt, indices); // fake ExternalPtrStmt
src_stmt->ret_type = type_factory.get_pointer_type(tensor_type);

block->push_back<T>(src_stmt);
dest_stmt->ret_type = type_factory.get_pointer_type(tensor_type);

std::vector<Stmt *> matrix_init_vals = {const_1_stmt, const_1_stmt,
const_2_stmt, const_2_stmt};
auto matrix_init_stmt =
block->push_back<MatrixInitStmt>(std::move(matrix_init_vals));
matrix_init_stmt->ret_type = tensor_type;

block->push_back<GlobalStoreStmt>(dest_stmt, matrix_init_stmt);

irpass::scalarize(block.get());
irpass::die(block.get());

EXPECT_EQ(block->size(), 1 /*argload*/ + 1 /*external_ptr*/ + 4 /*const*/ +
4 /*matrix_ptr*/ + 4 /*load*/ +
1 /*matrix_init*/);
EXPECT_EQ(block->size(), 2 /*const*/ + 1 /*argload*/ + 1 /*external_ptr*/ +
1 /*matrix_init*/ + 4 /*const*/ +
4 /*matrix_ptr*/ + 4 /*store*/);

// Check for scalarized statements
EXPECT_EQ(block->statements[2]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[3]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[4]->is<T>(), true);

EXPECT_EQ(block->statements[5]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[6]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[7]->is<T>(), true);
EXPECT_EQ(block->statements[7]->is<GlobalStoreStmt>(), true);

EXPECT_EQ(block->statements[8]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[9]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[10]->is<T>(), true);
EXPECT_EQ(block->statements[10]->is<GlobalStoreStmt>(), true);

EXPECT_EQ(block->statements[11]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[12]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[13]->is<T>(), true);
EXPECT_EQ(block->statements[13]->is<GlobalStoreStmt>(), true);

EXPECT_EQ(block->statements[14]->is<MatrixInitStmt>(), true);
EXPECT_EQ(block->statements[14]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[15]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[16]->is<GlobalStoreStmt>(), true);
}

TEST(Scalarize, ScalarizeGlobalStore) {
// Basic tests within a basic block
TEST(Scalarize, ScalarizeGlobalLoad) {
TestProgram test_prog;
test_prog.setup();

Expand All @@ -78,52 +85,44 @@ TEST(Scalarize, ScalarizeGlobalStore) {

/*
TensorType<4 x i32>* %1 = ExternalPtrStmt()
TensorType<4 x i32> %2 = MatrixInitStmt([1, 1, 2, 2])
StoreStmt(%1, %2)
TensorType<4 x i32> %2 = LoadStmt(%1)
*/
Type *tensor_type = type_factory.get_tensor_type(
{2, 2}, type_factory.get_primitive_type(PrimitiveTypeID::i32));
auto const_1_stmt = block->push_back<ConstStmt>(TypedConstant(1));
auto const_2_stmt = block->push_back<ConstStmt>(TypedConstant(2));
auto argload_stmt = block->push_back<ArgLoadStmt>(0 /*arg_id*/, tensor_type);

std::vector<Stmt *> indices = {};
Stmt *dest_stmt = block->push_back<ExternalPtrStmt>(
Stmt *src_stmt = block->push_back<ExternalPtrStmt>(
argload_stmt, indices); // fake ExternalPtrStmt
src_stmt->ret_type = type_factory.get_pointer_type(tensor_type);

dest_stmt->ret_type = type_factory.get_pointer_type(tensor_type);

std::vector<Stmt *> matrix_init_vals = {const_1_stmt, const_1_stmt,
const_2_stmt, const_2_stmt};
auto matrix_init_stmt =
block->push_back<MatrixInitStmt>(std::move(matrix_init_vals));
matrix_init_stmt->ret_type = tensor_type;

block->push_back<GlobalStoreStmt>(dest_stmt, matrix_init_stmt);
block->push_back<GlobalLoadStmt>(src_stmt);

irpass::scalarize(block.get());
irpass::die(block.get());

EXPECT_EQ(block->size(), 2 /*const*/ + 1 /*argload*/ + 1 /*external_ptr*/ +
1 /*matrix_init*/ + 4 /*const*/ +
4 /*matrix_ptr*/ + 4 /*store*/);
EXPECT_EQ(block->size(), 1 /*argload*/ + 1 /*external_ptr*/ + 4 /*const*/ +
4 /*matrix_ptr*/ + 4 /*load*/ +
1 /*matrix_init*/);

// Check for scalarized statements
EXPECT_EQ(block->statements[2]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[3]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[4]->is<GlobalLoadStmt>(), true);

EXPECT_EQ(block->statements[5]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[6]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[7]->is<GlobalStoreStmt>(), true);
EXPECT_EQ(block->statements[7]->is<GlobalLoadStmt>(), true);

EXPECT_EQ(block->statements[8]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[9]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[10]->is<GlobalStoreStmt>(), true);
EXPECT_EQ(block->statements[10]->is<GlobalLoadStmt>(), true);

EXPECT_EQ(block->statements[11]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[12]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[13]->is<GlobalStoreStmt>(), true);
EXPECT_EQ(block->statements[13]->is<GlobalLoadStmt>(), true);

EXPECT_EQ(block->statements[14]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[15]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[16]->is<GlobalStoreStmt>(), true);
EXPECT_EQ(block->statements[14]->is<MatrixInitStmt>(), true);
}

TEST(Scalarize, ScalarizeLocalStore) {
Expand Down Expand Up @@ -182,7 +181,7 @@ TEST(Scalarize, ScalarizeLocalStore) {
EXPECT_EQ(block->statements[10]->is<LocalStoreStmt>(), true);
}

TEST(Scalarize, ScalarizeLoadAlloca) {
TEST(Scalarize, ScalarizeLocalLoad) {
// Basic tests within a basic block
TestProgram test_prog;
test_prog.setup();
Expand Down Expand Up @@ -226,9 +225,4 @@ TEST(Scalarize, ScalarizeLoadAlloca) {
EXPECT_EQ(block->statements[8]->is<MatrixInitStmt>(), true);
}

TEST(Scalarize, ScalarizeLoad) {
test_load_scalarize<GlobalLoadStmt>();
test_load_scalarize<LocalLoadStmt>();
}

} // namespace taichi::lang

0 comments on commit 47675b6

Please sign in to comment.