Skip to content

Commit

Permalink
Addressed review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jim19930609 committed Sep 14, 2022
1 parent c8aa4d9 commit add09e6
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 59 deletions.
11 changes: 11 additions & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ TLANG_NAMESPACE_BEGIN
"[{}] was not type-checked", \
ExpressionHumanFriendlyPrinter::expr_to_string(x))

static bool is_primitive_or_tensor_type(DataType &type) {
return type->is<PrimitiveType>() || type->is<TensorType>();
}

FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type,
SNode *snode,
const ExprGroup &indices,
Expand Down Expand Up @@ -180,10 +184,17 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
binary_op_type_symbol(type), lhs->ret_type->to_string(),
rhs->ret_type->to_string()));
};

if (!is_primitive_or_tensor_type(lhs_type) ||
!is_primitive_or_tensor_type(rhs_type)) {
error();
}

if ((lhs_type->is<PrimitiveType>() && rhs_type->is<TensorType>()) ||
(lhs_type->is<TensorType>() && rhs_type->is<PrimitiveType>())) {
TI_NOT_IMPLEMENTED;
}

if (binary_is_bitwise(type) &&
(!is_integral(lhs_type) || !is_integral(rhs_type)))
error();
Expand Down
1 change: 0 additions & 1 deletion taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ void offload_to_executable(IRNode *ir,
irpass::full_simplify(
ir, config,
{lower_global_access, /*autodiff_enabled*/ false, kernel->program});

print("Simplified IV");

if (determine_ad_stack_size) {
Expand Down
60 changes: 2 additions & 58 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,58 +5,6 @@
#include "taichi/system/profiler.h"

TLANG_NAMESPACE_BEGIN
class MatrixInitRemoval : public IRVisitor {
public:
MatrixInitRemoval(IRNode *node, std::unordered_set<Stmt *> &&remove_list) {
allow_undefined_visitor = true;
invoke_default_visitor = false;
remove_list_ = std::move(remove_list);
node->accept(this);
}

void visit(Block *stmt_list) override {
for (auto &stmt : stmt_list->statements) {
stmt->accept(this);
}
}

void visit(IfStmt *if_stmt) override {
if (if_stmt->true_statements)
if_stmt->true_statements->accept(this);
if (if_stmt->false_statements) {
if_stmt->false_statements->accept(this);
}
}

void visit(WhileStmt *stmt) override {
stmt->body->accept(this);
}

void visit(RangeForStmt *for_stmt) override {
for_stmt->body->accept(this);
}

void visit(StructForStmt *for_stmt) override {
for_stmt->body->accept(this);
}

void visit(MeshForStmt *for_stmt) override {
for_stmt->body->accept(this);
}

void visit(OffloadedStmt *stmt) override {
stmt->all_blocks_accept(this);
}

void visit(MatrixInitStmt *stmt) override {
if (remove_list_.count(stmt)) {
stmt->parent->erase(stmt);
}
}

private:
std::unordered_set<Stmt *> remove_list_;
};

class Scalarize : public IRVisitor {
public:
Expand Down Expand Up @@ -101,7 +49,7 @@ class Scalarize : public IRVisitor {
auto matrix_init_stmt = stmt->val->template as<MatrixInitStmt>();

int num_elements = val_tensor_type->get_num_elements();
for (size_t i = 0; i < num_elements; i++) {
for (int i = 0; i < num_elements; i++) {
auto const_stmt = std::make_unique<ConstStmt>(
TypedConstant(stmt->val->ret_type.get_element_type(), i));

Expand Down Expand Up @@ -159,8 +107,6 @@ class Scalarize : public IRVisitor {
void visit(LocalStoreStmt *stmt) override {
scalarize_store_stmt<LocalStoreStmt>(stmt);
}

std::unordered_set<Stmt *> matrix_init_to_remove_;
};

namespace irpass {
Expand All @@ -169,12 +115,10 @@ void scalarize(IRNode *root) {
TI_AUTO_PROF;
Scalarize scalarize_pass(root);

/*
/* TODO(zhanlue): Remove redundant MatrixInitStmt
Scalarize pass will generate temporary MatrixInitStmts, which are only used
as rvalues. Remove these MatrixInitStmts since it's no longer needed.
*/
MatrixInitRemoval matrix_init_removal_pass(
root, std::move(scalarize_pass.matrix_init_to_remove_));
}

} // namespace irpass
Expand Down

0 comments on commit add09e6

Please sign in to comment.