diff --git a/python/taichi/lang/kernel.py b/python/taichi/lang/kernel.py
index 3b9e2e3fe36c3..fd72364409479 100644
--- a/python/taichi/lang/kernel.py
+++ b/python/taichi/lang/kernel.py
@@ -158,6 +158,7 @@ def __init__(self, func, is_grad, classkernel=False):
         self.is_grad = is_grad
         self.arguments = []
         self.argument_names = []
+        self.return_type = None
         self.classkernel = classkernel
         self.extract_arguments()
         self.template_slot_locations = []
@@ -180,6 +181,8 @@ def reset(self):
 
     def extract_arguments(self):
         sig = inspect.signature(self.func)
+        if sig.return_annotation not in (inspect._empty, None):
+            self.return_type = sig.return_annotation
         params = sig.parameters
         arg_names = params.keys()
         for i, arg_name in enumerate(arg_names):
@@ -249,6 +252,9 @@ def materialize(self, key=None, args=None, arg_features=None):
             if isinstance(anno, ast.Name):
                 global_vars[anno.id] = self.arguments[i]
 
+        if isinstance(func_body.returns, ast.Name):
+            global_vars[func_body.returns.id] = self.return_type
+
         if self.is_grad:
             from .ast_checker import KernelSimplicityASTChecker
             KernelSimplicityASTChecker(self.func).visit(tree)
@@ -388,12 +394,22 @@ def call_back():
 
             t_kernel()
 
+            ret = None
+            ret_dt = self.return_type
+            if ret_dt is not None:
+                if taichi_lang_core.is_integral(ret_dt):
+                    ret = t_kernel.get_ret_int(0)
+                else:
+                    ret = t_kernel.get_ret_float(0)
+
             if callbacks:
                 import taichi as ti
                 ti.sync()
                 for c in callbacks:
                     c()
 
+            return ret
+
         return func__
 
     def match_ext_arr(self, v, needed):
@@ -481,7 +497,7 @@ def wrapped(*args, **kwargs):
 
         @wraps(func)
         def wrapped(*args, **kwargs):
-            primal(*args, **kwargs)
+            return primal(*args, **kwargs)
 
         wrapped.grad = adjoint
diff --git a/python/taichi/lang/kernel_arguments.py b/python/taichi/lang/kernel_arguments.py
index 226f8385a0006..5a45ff3ae55d7 100644
--- a/python/taichi/lang/kernel_arguments.py
+++ b/python/taichi/lang/kernel_arguments.py
@@ -40,3 +40,8 @@ def decl_scalar_arg(dt):
 def decl_ext_arr_arg(dt, dim):
     id = taichi_lang_core.decl_arg(dt, True)
     return Expr(taichi_lang_core.make_external_tensor_expr(dt, dim, id))
+
+
+def decl_scalar_ret(dt):
+    id = taichi_lang_core.decl_ret(dt)
+    return id
diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py
index cbab3a94e7ed7..7e51650e242f3 100644
--- a/python/taichi/lang/transformer.py
+++ b/python/taichi/lang/transformer.py
@@ -37,6 +37,7 @@ def __init__(self,
         self.is_classfunc = is_classfunc
         self.func = func
         self.arg_features = arg_features
+        self.returns = None
 
     def variable_scope(self, *args):
         return ScopeGuard(self, *args)
@@ -594,6 +595,15 @@ def visit_FunctionDef(self, node):
                     "Function definition not allowed in 'ti.kernel'.")
             # Transform as kernel
             arg_decls = []
+
+            # Handle the return type annotation
+            if node.returns is not None:
+                ret_init = self.parse_stmt('ti.decl_scalar_ret(0)')
+                ret_init.value.args[0] = node.returns
+                self.returns = node.returns
+                arg_decls.append(ret_init)
+                node.returns = None
+
             for i, arg in enumerate(args.args):
                 if isinstance(self.func.arguments[i], ti.template):
                     continue
@@ -620,6 +630,7 @@ def visit_FunctionDef(self, node):
                 arg_decls.append(arg_init)
             # remove original args
             node.args.args = []
+
         else:  # ti.func
             for decorator in node.decorator_list:
                 if (isinstance(decorator, ast.Attribute)
@@ -640,8 +651,10 @@ def visit_FunctionDef(self, node):
                             '_by_value__')
                 args.args[i].arg += '_by_value__'
                 arg_decls.append(arg_init)
+
         with self.variable_scope():
             self.generic_visit(node)
+
         node.body = arg_decls + node.body
         return node
 
@@ -736,7 +749,16 @@ def visit_Assert(self, node):
 
     def visit_Return(self, node):
         self.generic_visit(node)
         if self.is_kernel:
-            raise TaichiSyntaxError(
-                '"return" not allowed in \'ti.kernel\'. Please walk around by storing the return result to a global variable.'
-            )
+            # TODO: check that the return is the last statement of the kernel; throw TaichiSyntaxError if not
+            if node.value is not None:
+                if self.returns is None:
+                    raise TaichiSyntaxError('a kernel with a return value must be '
+                                            'annotated with a return type, e.g. def func() -> ti.f32')
+                ret_expr = self.parse_expr('ti.cast(ti.Expr(0), 0)')
+                ret_expr.args[0].args[0] = node.value
+                ret_expr.args[1] = self.returns
+                ret_stmt = self.parse_stmt(
+                    'ti.core.create_kernel_return(ret.ptr)')
+                ret_stmt.value.args[0].value = ret_expr
+                return ret_stmt
         return node
diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp
index 74a65a20446b5..5c8806e09377c 100644
--- a/taichi/backends/opengl/codegen_opengl.cpp
+++ b/taichi/backends/opengl/codegen_opengl.cpp
@@ -498,6 +498,15 @@ class KernelGen : public IRVisitor {
          const_stmt->short_name(), const_stmt->val[0].stringify());
   }
 
+  void visit(KernelReturnStmt *stmt) override {
+    used.argument = true;
+    used.int64 = true;
+    // TODO: consider using _rets_{}_ instead of _args_{}_
+    // TODO: use stmt->ret_id instead of 0 as index
+    emit("_args_{}_[0] = {};", data_type_short_name(stmt->element_type()),
+         stmt->value->short_name());
+  }
+
   void visit(ArgLoadStmt *stmt) override {
     const auto dt = opengl_data_type_name(stmt->element_type());
     used.argument = true;
diff --git a/taichi/backends/opengl/opengl_api.cpp b/taichi/backends/opengl/opengl_api.cpp
index f80ceed371961..b6d42eb245e19 100644
--- a/taichi/backends/opengl/opengl_api.cpp
+++ b/taichi/backends/opengl/opengl_api.cpp
@@ -366,12 +366,13 @@ struct CompiledKernel {
 
 struct CompiledProgram::Impl {
   std::vector<std::unique_ptr<CompiledKernel>> kernels;
-  int arg_count;
+  int arg_count, ret_count;
   std::map<int, size_t> ext_arr_map;
   size_t gtmp_size;
 
   Impl(Kernel *kernel, size_t gtmp_size) : gtmp_size(gtmp_size) {
     arg_count = kernel->args.size();
+    ret_count = kernel->rets.size();
     for (int i = 0; i < arg_count; i++) {
       if (kernel->args[i].is_nparray) {
         ext_arr_map[i] = kernel->args[i].size;
@@ -390,7 +391,7 @@ struct CompiledProgram::Impl {
 
   void launch(Context &ctx, GLSLLauncher *launcher) const {
     std::vector<IOV> iov;
-    iov.push_back(IOV{ctx.args, arg_count * sizeof(uint64_t)});
+    iov.push_back(IOV{ctx.args, std::max(arg_count, ret_count) * sizeof(uint64_t)});
     auto gtmp_arr = std::vector<char>(gtmp_size);
     void *gtmp_base = gtmp_arr.data();  // std::calloc(gtmp_size, 1);
     iov.push_back(IOV{gtmp_base, gtmp_size});
diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp
index 58d4395ee08eb..4e0072b215e2e 100644
--- a/taichi/codegen/codegen_llvm.cpp
+++ b/taichi/codegen/codegen_llvm.cpp
@@ -881,6 +881,24 @@ void CodeGenLLVM::visit(ArgStoreStmt *stmt) {
   }
 }
 
+void CodeGenLLVM::visit(KernelReturnStmt *stmt) {
+  if (stmt->is_ptr) {
+    TI_NOT_IMPLEMENTED
+  } else {
+    auto intermediate_bits =
+        tlctx->get_data_type(stmt->value->ret_type.data_type)
+            ->getPrimitiveSizeInBits();
+    llvm::Type *intermediate_type =
+        llvm::Type::getIntNTy(*llvm_context, intermediate_bits);
+    llvm::Type *dest_ty = tlctx->get_data_type<int64>();
+    auto extended = builder->CreateZExt(
+        builder->CreateBitCast(llvm_val[stmt->value], intermediate_type),
+        dest_ty);
+    builder->CreateCall(get_runtime_function("LLVMRuntime_store_result"),
+                        {get_runtime(), extended});
+  }
+}
+
 void CodeGenLLVM::visit(LocalLoadStmt *stmt) {
   TI_ASSERT(stmt->width() == 1);
  llvm_val[stmt] = builder->CreateLoad(llvm_val[stmt->ptr[0].var]);
diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h
index 427a9464c5579..811d263f4914e 100644
--- a/taichi/codegen/codegen_llvm.h
+++ b/taichi/codegen/codegen_llvm.h
@@ -171,6 +171,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
 
   void visit(ArgStoreStmt *stmt) override;
 
+  void visit(KernelReturnStmt *stmt) override;
+
   void visit(LocalLoadStmt *stmt) override;
 
   void visit(LocalStoreStmt *stmt) override;
diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h
index 5ed9cd50731c4..c2c6ef2e44434 100644
--- a/taichi/inc/statements.inc.h
+++ b/taichi/inc/statements.inc.h
@@ -12,6 +12,7 @@ PER_STATEMENT(FrontendSNodeOpStmt)  // activate, deactivate, append, clear
 PER_STATEMENT(FrontendAssertStmt)
 PER_STATEMENT(FrontendArgStoreStmt)
 PER_STATEMENT(FrontendFuncDefStmt)
+PER_STATEMENT(FrontendKernelReturnStmt)
 
 // Middle-end statement
 
@@ -24,6 +25,7 @@ PER_STATEMENT(WhileControlStmt)
 PER_STATEMENT(ContinueStmt)
 PER_STATEMENT(FuncBodyStmt)
 PER_STATEMENT(FuncCallStmt)
+PER_STATEMENT(KernelReturnStmt)
 
 PER_STATEMENT(ArgLoadStmt)
 PER_STATEMENT(ExternalPtrStmt)
diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h
index 4df05e6d2d78f..c6508397a7ab4 100644
--- a/taichi/ir/frontend_ir.h
+++ b/taichi/ir/frontend_ir.h
@@ -198,6 +198,20 @@ class FrontendWhileStmt : public Stmt {
   DEFINE_ACCEPT
 };
 
+class FrontendKernelReturnStmt : public Stmt {
+ public:
+  Expr value;
+
+  FrontendKernelReturnStmt(const Expr &value) : value(value) {
+  }
+
+  bool is_container_statement() const override {
+    return false;
+  }
+
+  DEFINE_ACCEPT
+};
+
 // Expressions
 
 class ArgLoadExpression : public Expression {
diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h
index e80f17fe6ad5e..1e704fd37f3e5 100644
--- a/taichi/ir/ir.h
+++ b/taichi/ir/ir.h
@@ -1453,6 +1453,22 @@ class FuncCallStmt : public Stmt {
   DEFINE_ACCEPT
 };
 
+class KernelReturnStmt : public Stmt {
+ public:
+  Stmt *value;
+
+  KernelReturnStmt(Stmt *value) : value(value) {
+    TI_STMT_REG_FIELDS;
+  }
+
+  bool is_container_statement() const override {
+    return false;
+  }
+
+  TI_STMT_DEF_FIELDS(value);
+  DEFINE_ACCEPT
+};
+
 class WhileStmt : public Stmt {
  public:
   Stmt *mask;
diff --git a/taichi/ir/snode.cpp b/taichi/ir/snode.cpp
index 36be0efd5a14c..cd80d1a56b20a 100644
--- a/taichi/ir/snode.cpp
+++ b/taichi/ir/snode.cpp
@@ -122,25 +122,6 @@ void SNode::write_float(const std::vector<int> &I, float64 val) {
   (*writer_kernel)();
 }
 
-uint64 SNode::fetch_reader_result() {
-  uint64 ret;
-  auto arch = get_current_program().config.arch;
-  if (arch == Arch::cuda) {
-    // TODO: refactor
-#if defined(TI_WITH_CUDA)
-    CUDADriver::get_instance().memcpy_device_to_host(
-        &ret, get_current_program().result_buffer, sizeof(uint64));
-#else
-    TI_NOT_IMPLEMENTED;
-#endif
-  } else if (arch_is_cpu(arch)) {
-    ret = *(uint64 *)get_current_program().result_buffer;
-  } else {
-    ret = get_current_program().context.get_arg_as_uint64(num_active_indices);
-  }
-  return ret;
-}
-
 float64 SNode::read_float(const std::vector<int> &I) {
   if (reader_kernel == nullptr) {
     reader_kernel = &get_current_program().get_snode_reader(this);
@@ -149,14 +130,8 @@ float64 SNode::read_float(const std::vector<int> &I) {
   get_current_program().synchronize();
   (*reader_kernel)();
   get_current_program().synchronize();
-  auto ret = fetch_reader_result();
-  if (dt == DataType::f32) {
-    return taichi_union_cast_with_different_sizes<float32>(ret);
-  } else if (dt == DataType::f64) {
-    return taichi_union_cast_with_different_sizes<float64>(ret);
-  } else {
-    TI_NOT_IMPLEMENTED
-  }
+  auto ret = reader_kernel->get_ret_float(0);
+  return ret;
 }
 
 // for int32 and int64
@@ -178,26 +153,8 @@ int64 SNode::read_int(const std::vector<int> &I) {
   get_current_program().synchronize();
   (*reader_kernel)();
   get_current_program().synchronize();
-  auto ret = fetch_reader_result();
-  if (dt == DataType::i32) {
-    return taichi_union_cast_with_different_sizes<int32>(ret);
-  } else if (dt == DataType::i64) {
-    return taichi_union_cast_with_different_sizes<int64>(ret);
-  } else if (dt == DataType::i8) {
-    return taichi_union_cast_with_different_sizes<int8>(ret);
-  } else if (dt == DataType::i16) {
-    return taichi_union_cast_with_different_sizes<int16>(ret);
-  } else if (dt == DataType::u8) {
-    return taichi_union_cast_with_different_sizes<uint8>(ret);
-  } else if (dt == DataType::u16) {
-    return taichi_union_cast_with_different_sizes<uint16>(ret);
-  } else if (dt == DataType::u32) {
-    return taichi_union_cast_with_different_sizes<uint32>(ret);
-  } else if (dt == DataType::u64) {
-    return taichi_union_cast_with_different_sizes<uint64>(ret);
-  } else {
-    TI_NOT_IMPLEMENTED
-  }
+  auto ret = reader_kernel->get_ret_int(0);
+  return ret;
 }
 
 uint64 SNode::read_uint(const std::vector<int> &I) {
diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp
index ce833b1b5b41b..1f3cfdde972ce 100644
--- a/taichi/program/kernel.cpp
+++ b/taichi/program/kernel.cpp
@@ -4,6 +4,7 @@
 #include "taichi/program/program.h"
 #include "taichi/program/async_engine.h"
 #include "taichi/codegen/codegen.h"
+#include "taichi/backends/cuda/cuda_driver.h"
 
 TLANG_NAMESPACE_BEGIN
 
@@ -111,13 +112,9 @@ void Kernel::set_arg_float(int i, float64 d) {
   }
 }
 
-void Kernel::set_extra_arg_int(int i, int j, int32 d) {
-  program.context.extra_args[i][j] = d;
-}
-
 void Kernel::set_arg_int(int i, int64 d) {
   TI_ASSERT_INFO(
-      args[i].is_nparray == false,
+      !args[i].is_nparray,
       "Assigning scalar value to numpy array argument is not allowed");
   auto dt = args[i].dt;
   if (dt == DataType::i32) {
@@ -145,10 +142,68 @@ void Kernel::set_arg_int(int i, int64 d) {
   }
 }
 
+float64 Kernel::get_ret_float(int i) {
+  auto dt = rets[i].dt;
+  if (dt == DataType::f32) {
+    return (float64)get_current_program().fetch_result<float32>(i);
+  } else if (dt == DataType::f64) {
+    return (float64)get_current_program().fetch_result<float64>(i);
+  } else if (dt == DataType::i32) {
+    return (float64)get_current_program().fetch_result<int32>(i);
+  } else if (dt == DataType::i64) {
+    return (float64)get_current_program().fetch_result<int64>(i);
+  } else if (dt == DataType::i8) {
+    return (float64)get_current_program().fetch_result<int8>(i);
+  } else if (dt == DataType::i16) {
+    return (float64)get_current_program().fetch_result<int16>(i);
+  } else if (dt == DataType::u8) {
+    return (float64)get_current_program().fetch_result<uint8>(i);
+  } else if (dt == DataType::u16) {
+    return (float64)get_current_program().fetch_result<uint16>(i);
+  } else if (dt == DataType::u32) {
+    return (float64)get_current_program().fetch_result<uint32>(i);
+  } else if (dt == DataType::u64) {
+    return (float64)get_current_program().fetch_result<uint64>(i);
+  } else {
+    TI_NOT_IMPLEMENTED
+  }
+}
+
+int64 Kernel::get_ret_int(int i) {
+  auto dt = rets[i].dt;
+  if (dt == DataType::i32) {
+    return (int64)get_current_program().fetch_result<int32>(i);
+  } else if (dt == DataType::i64) {
+    return (int64)get_current_program().fetch_result<int64>(i);
+  } else if (dt == DataType::i8) {
+    return (int64)get_current_program().fetch_result<int8>(i);
+  } else if (dt == DataType::i16) {
+    return (int64)get_current_program().fetch_result<int16>(i);
+  } else if (dt == DataType::u8) {
+    return (int64)get_current_program().fetch_result<uint8>(i);
+  } else if (dt == DataType::u16) {
+    return (int64)get_current_program().fetch_result<uint16>(i);
+  } else if (dt == DataType::u32) {
+    return (int64)get_current_program().fetch_result<uint32>(i);
+  } else if (dt == DataType::u64) {
+    return (int64)get_current_program().fetch_result<uint64>(i);
+  } else if (dt == DataType::f32) {
+    return (int64)get_current_program().fetch_result<float32>(i);
+  } else if (dt == DataType::f64) {
+    return (int64)get_current_program().fetch_result<float64>(i);
+  } else {
+    TI_NOT_IMPLEMENTED
+  }
+}
+
 void Kernel::mark_arg_return_value(int i, bool is_return) {
   args[i].is_return_value = is_return;
 }
 
+void Kernel::set_extra_arg_int(int i, int j, int32 d) {
+  program.context.extra_args[i][j] = d;
+}
+
 void Kernel::set_arg_nparray(int i, uint64 ptr, uint64 size) {
   TI_ASSERT_INFO(args[i].is_nparray,
                  "Assigning numpy array to scalar argument is not allowed");
@@ -166,4 +221,9 @@ int Kernel::insert_arg(DataType dt, bool is_nparray) {
   return args.size() - 1;
 }
 
+int Kernel::insert_ret(DataType dt) {
+  rets.push_back(Ret{dt});
+  return rets.size() - 1;
+}
+
 TLANG_NAMESPACE_END
diff --git a/taichi/program/kernel.h b/taichi/program/kernel.h
index f5aabb18893d9..55c2fd2b3c934 100644
--- a/taichi/program/kernel.h
+++ b/taichi/program/kernel.h
@@ -35,7 +35,16 @@ class Kernel {
           is_return_value(is_return_value) {
     }
   };
+
+  struct Ret {
+    DataType dt;
+
+    explicit Ret(DataType dt = DataType::unknown) : dt(dt) {
+    }
+  };
+
   std::vector<Arg> args;
+  std::vector<Ret> rets;
   bool is_accessor;
   bool grad;
@@ -56,10 +65,16 @@ class Kernel {
 
   int insert_arg(DataType dt, bool is_nparray);
 
+  int insert_ret(DataType dt);
+
   void set_arg_float(int i, float64 d);
 
   void set_arg_int(int i, int64 d);
 
+  float64 get_ret_float(int i);
+
+  int64 get_ret_int(int i);
+
   void set_extra_arg_int(int i, int j, int32 d);
 
   void mark_arg_return_value(int i, bool is_return = true);
diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp
index cd753ccd02254..b24c85741b519 100644
--- a/taichi/program/program.cpp
+++ b/taichi/program/program.cpp
@@ -449,8 +449,8 @@ Kernel &Program::get_snode_reader(SNode *snode) {
     for (int i = 0; i < snode->num_active_indices; i++) {
       indices.push_back(Expr::make<ArgLoadExpression>(i));
     }
-    auto ret = Stmt::make<FrontendArgStoreStmt>(
-        snode->num_active_indices, load_if_ptr((snode->expr)[indices]));
+    auto ret = Stmt::make<FrontendKernelReturnStmt>(
+        load_if_ptr((snode->expr)[indices]));
     current_ast_builder().insert(std::move(ret));
   });
   ker.set_arch(get_snode_accessor_arch());
@@ -458,8 +458,7 @@ Kernel &Program::get_snode_reader(SNode *snode) {
   ker.is_accessor = true;
   for (int i = 0; i < snode->num_active_indices; i++)
     ker.insert_arg(DataType::i32, false);
-  auto ret_val = ker.insert_arg(snode->dt, false);
-  ker.mark_arg_return_value(ret_val);
+  ker.insert_ret(snode->dt);
   return ker;
 }
 
@@ -483,6 +482,26 @@ Kernel &Program::get_snode_writer(SNode *snode) {
   return ker;
 }
 
+uint64 Program::fetch_result_uint64(int i) {
+  uint64 ret;
+  auto arch = config.arch;
+  if (arch == Arch::cuda) {
+    // TODO: refactor
+    // We use a `memcpy_device_to_host` call here even if we have unified
+    // memory. This simplifies the code. Also note that a unified-memory (4KB)
+    // page fault is rather expensive for reading 4-8 bytes.
+#if defined(TI_WITH_CUDA)
+    CUDADriver::get_instance().memcpy_device_to_host(
+        &ret, (uint64 *)result_buffer + i, sizeof(uint64));
+#else
+    TI_NOT_IMPLEMENTED;
+#endif
+  } else if (arch_is_cpu(arch)) {
+    ret = ((uint64 *)result_buffer)[i];
+  } else {
+    ret = context.get_arg_as_uint64(i);
+  }
+  return ret;
+}
+
 void Program::finalize() {
   synchronize();
   TI_TRACE("Program finalizing...");
diff --git a/taichi/program/program.h b/taichi/program/program.h
index 345fae8723007..5dad57696b8fc 100644
--- a/taichi/program/program.h
+++ b/taichi/program/program.h
@@ -159,6 +159,13 @@ class Program {
 
   Kernel &get_snode_writer(SNode *snode);
 
+  uint64 fetch_result_uint64(int i);
+
+  template <typename T>
+  T fetch_result(int i) {
+    return taichi_union_cast_with_different_sizes<T>(fetch_result_uint64(i));
+  }
+
   Arch get_host_arch() {
     return host_arch();
   }
diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp
index c9eceeac17d22..5927cc941b965 100644
--- a/taichi/python/export_lang.cpp
+++ b/taichi/python/export_lang.cpp
@@ -193,9 +193,11 @@ void export_lang(py::module &m) {
   py::class_<Kernel>(m, "Kernel")
       .def("set_arg_int", &Kernel::set_arg_int)
-      .def("set_extra_arg_int", &Kernel::set_extra_arg_int)
       .def("set_arg_float", &Kernel::set_arg_float)
       .def("set_arg_nparray", &Kernel::set_arg_nparray)
+      .def("set_extra_arg_int", &Kernel::set_extra_arg_int)
+      .def("get_ret_int", &Kernel::get_ret_int)
+      .def("get_ret_float", &Kernel::get_ret_float)
       .def("__call__", [](Kernel *kernel) {
         py::gil_scoped_release release;
         kernel->operator()();
@@ -305,6 +307,10 @@ void export_lang(py::module &m) {
     current_ast_builder().insert(Stmt::make());
   });
 
+  m.def("create_kernel_return", [&](const Expr &value) {
+    current_ast_builder().insert(Stmt::make<FrontendKernelReturnStmt>(value));
+  });
+
   m.def("insert_continue_stmt", [&]() {
     current_ast_builder().insert(Stmt::make());
   });
@@ -491,6 +497,10 @@ void export_lang(py::module &m) {
                              is_nparray);
   });
 
+  m.def("decl_ret", [&](DataType dt) {
+    return get_current_program().get_current_kernel().insert_ret(dt);
+  });
+
   m.def("test_throw", [] {
     try {
       throw IRModified();
diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp
index 9f39547b354bc..5c51e3ee45e91 100644
--- a/taichi/transforms/ir_printer.cpp
+++ b/taichi/transforms/ir_printer.cpp
@@ -319,6 +319,16 @@ class IRPrinter : public IRVisitor {
           stmt->arg_id, stmt->val->name());
   }
 
+  void visit(FrontendKernelReturnStmt *stmt) override {
+    print("{}{} : kernel return {}", stmt->type_hint(), stmt->name(),
+          stmt->value->serialize());
+  }
+
+  void visit(KernelReturnStmt *stmt) override {
+    print("{}{} : kernel return {}", stmt->type_hint(), stmt->name(),
+          stmt->value->name());
+  }
+
   void visit(LocalLoadStmt *stmt) override {
     print("{}{} = local load [{}]", stmt->type_hint(), stmt->name(),
           to_string(stmt->ptr));
diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp
index 27bb19bb3da82..d22a49a15f97c 100644
--- a/taichi/transforms/lower_ast.cpp
+++ b/taichi/transforms/lower_ast.cpp
@@ -284,6 +284,15 @@ class LowerAST : public IRVisitor {
     capturing_loop = old_capturing_loop;
   }
 
+  void visit(FrontendKernelReturnStmt *stmt) override {
+    auto expr = stmt->value;
+    auto fctx = make_flatten_ctx();
+    expr->flatten(&fctx);
+    fctx.push_back<KernelReturnStmt>(fctx.back_stmt());
+    stmt->parent->replace_with(stmt, std::move(fctx.stmts));
+    throw IRModified();
+  }
+
   void visit(FrontendEvalStmt *stmt) override {
     // expand rhs
    auto expr = stmt->expr;
diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp
index 0f31e4dfce5ec..0ade72fc57888 100644
--- a/taichi/transforms/type_check.cpp
+++ b/taichi/transforms/type_check.cpp
@@ -335,6 +335,19 @@ class TypeCheck : public IRVisitor {
     stmt->ret_type = VectorType(1, arg_type);
   }
 
+  void visit(KernelReturnStmt *stmt) {
+    Kernel *current_kernel = kernel;
+    if (current_kernel == nullptr) {
+      current_kernel = &get_current_program().get_current_kernel();
+    }
+    auto &rets = current_kernel->rets;
+    TI_ASSERT(rets.size() >= 1);
+    auto ret = rets[0];  // TODO: stmt->ret_id?
+    auto ret_type = ret.dt;
+    TI_ASSERT(stmt->value->ret_type.data_type == ret_type);
+    stmt->ret_type = VectorType(1, ret_type);
+  }
+
   void visit(ExternalPtrStmt *stmt) {
     stmt->ret_type.set_is_pointer(true);
     stmt->ret_type = VectorType(stmt->base_ptrs.size(),
diff --git a/tests/python/test_return.py b/tests/python/test_return.py
index 5306dcdb6a5c1..71e383ec275a7 100644
--- a/tests/python/test_return.py
+++ b/tests/python/test_return.py
@@ -1,10 +1,67 @@
 import taichi as ti
+from taichi import approx
 
 
 @ti.must_throw(ti.TaichiSyntaxError)
-def test_return_in_kernel():
+def _test_return_not_last_stmt():  # TODO: make this work
+    x = ti.var(ti.i32, ())
+
+    @ti.kernel
+    def kernel() -> ti.i32:
+        return 1
+        x[None] = 233
+
+    kernel()
+
+
+@ti.must_throw(ti.TaichiSyntaxError)
+def test_return_without_type_hint():
+
     @ti.kernel
     def kernel():
         return 1
 
     kernel()
+
+
+def test_const_func_ret():
+
+    @ti.kernel
+    def func1() -> ti.f32:
+        return 3
+
+    @ti.kernel
+    def func2() -> ti.i32:
+        return 3.3  # return type mismatch, will be auto-casted into ti.i32
+
+    assert func1() == approx(3)
+    assert func2() == 3
+
+
+@ti.all_archs
+def _test_binary_func_ret(dt1, dt2, dt3, castor):
+
+    @ti.kernel
+    def func(a: dt1, b: dt2) -> dt3:
+        return a * b
+
+    if ti.core.is_integral(dt1):
+        xs = list(range(4))
+    else:
+        xs = [0.2, 0.4, 0.8, 1.0]
+
+    if ti.core.is_integral(dt2):
+        ys = list(range(4))
+    else:
+        ys = [0.2, 0.4, 0.8, 1.0]
+
+    for x, y in zip(xs, ys):
+        assert func(x, y) == approx(castor(x * y))
+
+
+def test_binary_func_ret():
+    _test_binary_func_ret(ti.i32, ti.f32, ti.f32, float)
+    _test_binary_func_ret(ti.f32, ti.i32, ti.f32, float)
+    _test_binary_func_ret(ti.i32, ti.f32, ti.i32, int)
+    _test_binary_func_ret(ti.f32, ti.i32, ti.i32, int)
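
For reference, a minimal usage sketch of what this patch enables, distilled from tests/python/test_return.py above. The ti.init() call, the kernel name mul, and the sample values are illustrative assumptions and are not part of the patch:

import taichi as ti

ti.init()  # assumed setup; defaults to a CPU backend

@ti.kernel
def mul(a: ti.i32, b: ti.f32) -> ti.f32:
    # The '-> ti.f32' annotation is mandatory: it is declared through
    # ti.decl_scalar_ret(), and the returned expression is ti.cast() to it
    # before ti.core.create_kernel_return() stores it in the result buffer.
    return a * b

# The Python-side wrapper now fetches the value with get_ret_float() or
# get_ret_int() after the launch, so the call itself evaluates to a number.
print(mul(2, 0.4))  # roughly 0.8

As the TODO in visit_Return notes, only a single scalar return as the last statement of the kernel body is handled for now, and the return slot index is hard-coded to 0.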