diff --git a/python/taichi/lang/kernel.py b/python/taichi/lang/kernel.py index 3b9e2e3fe36c34..18ebb6fd2b1711 100644 --- a/python/taichi/lang/kernel.py +++ b/python/taichi/lang/kernel.py @@ -387,6 +387,8 @@ def call_back(): self.runtime.target_tape.insert(self, args) t_kernel() + # todo union_cast from uint64 + ret = taichi_lang_core.fetch_return_result() if callbacks: import taichi as ti @@ -394,6 +396,8 @@ def call_back(): for c in callbacks: c() + return ret + return func__ def match_ext_arr(self, v, needed): @@ -412,7 +416,8 @@ def __call__(self, *args, **kwargs): instance_id, arg_features = self.mapper.lookup(args) key = (self.func, instance_id) self.materialize(key=key, args=args, arg_features=arg_features) - return self.compiled_functions[key](*args) + ret = self.compiled_functions[key](*args) + return ret # For a Taichi class definition like below: @@ -481,7 +486,7 @@ def wrapped(*args, **kwargs): @wraps(func) def wrapped(*args, **kwargs): - primal(*args, **kwargs) + return primal(*args, **kwargs) wrapped.grad = adjoint diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 8275787e188d53..773ec478a7979b 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -309,6 +309,10 @@ void export_lang(py::module &m) { current_ast_builder().insert(Stmt::make(value)); }); + m.def("fetch_return_result", []() { + return get_current_program().context.get_arg_as_uint64(0); + }); + m.def("insert_continue_stmt", [&]() { current_ast_builder().insert(Stmt::make()); });