Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] [IR] Kernel scalar return support (ArgStoreStmt -> KernelReturnStmt) #917

Merged
merged 25 commits into from
May 8, 2020
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
@@ -758,7 +758,7 @@ def visit_Return(self, node):
ret_expr.args[0].args[0] = node.value
ret_expr.args[1] = self.returns
ret_stmt = self.parse_stmt(
'ti.core.create_kernel_return(ret.ptr)')
'ti.core.create_kernel_return(ret.ptr)')
ret_stmt.value.args[0].value = ret_expr
return ret_stmt
return node
3 changes: 1 addition & 2 deletions taichi/backends/opengl/codegen_opengl.cpp
Original file line number Diff line number Diff line change
@@ -503,8 +503,7 @@ class KernelGen : public IRVisitor {
used.int64 = true;
// TODO: consider use _rets_{}_ instead of _args_{}_
// TODO: use stmt->ret_id instead of 0 as index
emit("_args_{}_[0] = {};",
data_type_short_name(stmt->element_type()),
emit("_args_{}_[0] = {};", data_type_short_name(stmt->element_type()),
stmt->value->short_name());
}

5 changes: 3 additions & 2 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
@@ -885,8 +885,9 @@ void CodeGenLLVM::visit(KernelReturnStmt *stmt) {
if (stmt->is_ptr) {
TI_NOT_IMPLEMENTED
} else {
auto intermediate_bits = tlctx->get_data_type(stmt->value->ret_type.data_type)
->getPrimitiveSizeInBits();
auto intermediate_bits =
tlctx->get_data_type(stmt->value->ret_type.data_type)
->getPrimitiveSizeInBits();
llvm::Type *intermediate_type =
llvm::Type::getIntNTy(*llvm_context, intermediate_bits);
llvm::Type *dest_ty = tlctx->get_data_type<int64>();
9 changes: 4 additions & 5 deletions taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
@@ -143,16 +143,15 @@ void Kernel::set_arg_int(int i, int64 d) {
}

// NOTE: sync with snode.cpp: fetch_reader_result
static uint64 fetch_result_uint64(int i)
{
static uint64 fetch_result_uint64(int i) {
uint64 ret;
auto arch = get_current_program().config.arch;
if (arch == Arch::cuda) {
// TODO: refactor
// XXX: what about unified memory?
#if defined(TI_WITH_CUDA)
CUDADriver::get_instance().memcpy_device_to_host(&ret,
(uint64 *)get_current_program().result_buffer + i,
CUDADriver::get_instance().memcpy_device_to_host(
&ret, (uint64 *)get_current_program().result_buffer + i,
sizeof(uint64));
#else
TI_NOT_IMPLEMENTED;
@@ -166,7 +165,7 @@ static uint64 fetch_result_uint64(int i)
}

template <typename T>
static T fetch_result(int i) // TODO: move to Program::fetch_result
static T fetch_result(int i) // TODO: move to Program::fetch_result
{
return taichi_union_cast_with_different_sizes<T>(fetch_result_uint64(i));
}
3 changes: 1 addition & 2 deletions taichi/program/kernel.h
Original file line number Diff line number Diff line change
@@ -39,8 +39,7 @@ class Kernel {
struct Ret {
DataType dt;

explicit Ret(DataType dt = DataType::unknown)
: dt(dt) {
explicit Ret(DataType dt = DataType::unknown) : dt(dt) {
}
};

2 changes: 1 addition & 1 deletion taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
@@ -342,7 +342,7 @@ class TypeCheck : public IRVisitor {
}
auto &rets = current_kernel->rets;
TI_ASSERT(rets.size() >= 1);
auto ret = rets[0]; // TODO: stmt->ret_id?
auto ret = rets[0]; // TODO: stmt->ret_id?
auto ret_type = ret.dt;
TI_ASSERT(stmt->value->ret_type.data_type == ret_type);
stmt->ret_type = VectorType(1, ret_type);