Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ir] Rename KernelReturnStmt to ReturnStmt #2349

Merged
merged 9 commits into from
May 18, 2021
2 changes: 1 addition & 1 deletion taichi/backends/cc/codegen_cc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class CCTransformer : public IRVisitor {
}
}

void visit(KernelReturnStmt *stmt) override {
void visit(ReturnStmt *stmt) override {
emit("ti_ctx->args[0].val_{} = {};", data_type_name(stmt->element_type()),
stmt->value->raw_name());
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ class KernelCodegen : public IRVisitor {
}
}

void visit(KernelReturnStmt *stmt) override {
void visit(ReturnStmt *stmt) override {
// TODO: use stmt->ret_id instead of 0 as index
emit("*{}.ret0() = {};", kContextVarName, stmt->value->raw_name());
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/opengl/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ class KernelGen : public IRVisitor {
const_stmt->val[0].stringify());
}

void visit(KernelReturnStmt *stmt) override {
void visit(ReturnStmt *stmt) override {
used.buf_args = true;
// TODO: consider use _rets_{}_ instead of _args_{}_
// TODO: use stmt->ret_id instead of 0 as index
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {
}
}

void CodeGenLLVM::visit(KernelReturnStmt *stmt) {
void CodeGenLLVM::visit(ReturnStmt *stmt) {
if (stmt->ret_type.is_pointer()) {
TI_NOT_IMPLEMENTED
} else {
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(ArgLoadStmt *stmt) override;

void visit(KernelReturnStmt *stmt) override;
void visit(ReturnStmt *stmt) override;

void visit(LocalLoadStmt *stmt) override;

Expand Down
4 changes: 2 additions & 2 deletions taichi/inc/statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ PER_STATEMENT(FrontendEvalStmt)
PER_STATEMENT(FrontendSNodeOpStmt) // activate, deactivate, append, clear
PER_STATEMENT(FrontendAssertStmt)
PER_STATEMENT(FrontendFuncDefStmt)
PER_STATEMENT(FrontendKernelReturnStmt)
PER_STATEMENT(FrontendReturnStmt)

// Middle-end statement

Expand All @@ -25,7 +25,7 @@ PER_STATEMENT(WhileControlStmt)
PER_STATEMENT(ContinueStmt)
PER_STATEMENT(FuncBodyStmt)
PER_STATEMENT(FuncCallStmt)
PER_STATEMENT(KernelReturnStmt)
PER_STATEMENT(ReturnStmt)

PER_STATEMENT(ArgLoadStmt)
PER_STATEMENT(ExternalPtrStmt)
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,11 @@ class FrontendWhileStmt : public Stmt {
TI_DEFINE_ACCEPT
};

class FrontendKernelReturnStmt : public Stmt {
class FrontendReturnStmt : public Stmt {
public:
Expr value;

FrontendKernelReturnStmt(const Expr &value) : value(value) {
FrontendReturnStmt(const Expr &value) : value(value) {
}

bool is_container_statement() const override {
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ ArgLoadStmt *IRBuilder::create_arg_load(int arg_id, DataType dt, bool is_ptr) {
return insert(Stmt::make_typed<ArgLoadStmt>(arg_id, dt, is_ptr));
}

KernelReturnStmt *IRBuilder::create_return(Stmt *value) {
return insert(Stmt::make_typed<KernelReturnStmt>(value));
ReturnStmt *IRBuilder::create_return(Stmt *value) {
return insert(Stmt::make_typed<ReturnStmt>(value));
}

UnaryOpStmt *IRBuilder::create_cast(Stmt *value, DataType output_type) {
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class IRBuilder {
ArgLoadStmt *create_arg_load(int arg_id, DataType dt, bool is_ptr);

// The return value of the kernel.
KernelReturnStmt *create_return(Stmt *value);
ReturnStmt *create_return(Stmt *value);

// Unary operations. Returns the result.
UnaryOpStmt *create_cast(Stmt *value, DataType output_type); // cast by value
Expand Down
6 changes: 3 additions & 3 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -764,13 +764,13 @@ class FuncCallStmt : public Stmt {
};

/**
* Exit the kernel with a return value.
* Exit the kernel or function with a return value.
*/
class KernelReturnStmt : public Stmt {
class ReturnStmt : public Stmt {
public:
Stmt *value;

KernelReturnStmt(Stmt *value) : value(value) {
explicit ReturnStmt(Stmt *value) : value(value) {
TI_STMT_REG_FIELDS;
}

Expand Down
2 changes: 1 addition & 1 deletion taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ Kernel &Program::get_snode_reader(SNode *snode) {
for (int i = 0; i < snode->num_active_indices; i++) {
indices.push_back(Expr::make<ArgLoadExpression>(i, PrimitiveType::i32));
}
auto ret = Stmt::make<FrontendKernelReturnStmt>(
auto ret = Stmt::make<FrontendReturnStmt>(
load_if_ptr(Expr(snode_to_glb_var_exprs_.at(snode))[indices]));
current_ast_builder().insert(std::move(ret));
});
Expand Down
2 changes: 1 addition & 1 deletion taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ void export_lang(py::module &m) {
});

m.def("create_kernel_return", [&](const Expr &value) {
current_ast_builder().insert(Stmt::make<FrontendKernelReturnStmt>(value));
current_ast_builder().insert(Stmt::make<FrontendReturnStmt>(value));
});

m.def("insert_continue_stmt", [&]() {
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ConstantFold : public BasicStmtVisitor {
oper->cast<UnaryOpStmt>()->cast_type = id.rhs;
}
}
auto ret = Stmt::make<KernelReturnStmt>(oper.get());
auto ret = Stmt::make<ReturnStmt>(oper.get());
current_ast_builder().insert(std::move(lhstmt));
if (id.is_binary)
current_ast_builder().insert(std::move(rhstmt));
Expand Down
6 changes: 3 additions & 3 deletions taichi/transforms/inlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Inliner : public BasicStmtVisitor {
std::move(inlined_ir->as<Block>()->statements));
} else {
if (irpass::analysis::gather_statements(inlined_ir.get(), [&](Stmt *s) {
return s->is<KernelReturnStmt>();
return s->is<ReturnStmt>();
}).size() > 1) {
TI_WARN(
"Multiple returns in function \"{}\" may not be handled properly.",
Expand All @@ -46,11 +46,11 @@ class Inliner : public BasicStmtVisitor {
Stmt::make<AllocaStmt>(func->rets[0].dt), /*location=*/0);
irpass::replace_and_insert_statements(
inlined_ir.get(),
/*filter=*/[&](Stmt *s) { return s->is<KernelReturnStmt>(); },
/*filter=*/[&](Stmt *s) { return s->is<ReturnStmt>(); },
/*generator=*/
[&](Stmt *s) {
return Stmt::make<LocalStoreStmt>(return_address,
s->as<KernelReturnStmt>()->value);
s->as<ReturnStmt>()->value);
});
modifier.insert_before(stmt,
std::move(inlined_ir->as<Block>()->statements));
Expand Down
4 changes: 2 additions & 2 deletions taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,12 +387,12 @@ class IRPrinter : public IRVisitor {
print("{}{} = arg[{}]", stmt->type_hint(), stmt->name(), stmt->arg_id);
}

void visit(FrontendKernelReturnStmt *stmt) override {
void visit(FrontendReturnStmt *stmt) override {
print("{}{} : kernel return {}", stmt->type_hint(), stmt->name(),
stmt->value->serialize());
}

void visit(KernelReturnStmt *stmt) override {
void visit(ReturnStmt *stmt) override {
print("{}{} : kernel return {}", stmt->type_hint(), stmt->name(),
stmt->value->name());
}
Expand Down
4 changes: 2 additions & 2 deletions taichi/transforms/lower_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,11 @@ class LowerAST : public IRVisitor {
capturing_loop = old_capturing_loop;
}

void visit(FrontendKernelReturnStmt *stmt) override {
void visit(FrontendReturnStmt *stmt) override {
auto expr = stmt->value;
auto fctx = make_flatten_ctx();
expr->flatten(&fctx);
fctx.push_back<KernelReturnStmt>(fctx.back_stmt());
fctx.push_back<ReturnStmt>(fctx.back_stmt());
stmt->parent->replace_with(stmt, std::move(fctx.stmts));
throw IRModified();
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class TypeCheck : public IRVisitor {
stmt->ret_type.set_is_pointer(stmt->is_ptr);
}

void visit(KernelReturnStmt *stmt) {
void visit(ReturnStmt *stmt) {
// TODO: Support stmt->ret_id?
stmt->ret_type = stmt->value->ret_type;
TI_ASSERT(stmt->ret_type->vector_width() == 1);
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/transforms/binary_op_simplify_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ TEST_F(BinaryOpSimplifyTest, MultiplyPOT) {
auto *bin_op = ir_block->statements[2]->as<BinaryOpStmt>();
EXPECT_EQ(bin_op->op_type, BinaryOpType::bit_shl);
EXPECT_EQ(bin_op->rhs, const_stmt);
EXPECT_TRUE(ir_block->statements[3]->is<KernelReturnStmt>());
EXPECT_EQ(ir_block->statements[3]->as<KernelReturnStmt>()->value, bin_op);
EXPECT_TRUE(ir_block->statements[3]->is<ReturnStmt>());
EXPECT_EQ(ir_block->statements[3]->as<ReturnStmt>()->value, bin_op);
}

} // namespace lang
Expand Down