Skip to content

Commit

Permalink
[Lang] Support matrix return value for real function
Browse files Browse the repository at this point in the history
ghstack-source-id: 82962dee39382925290a60102875946f368263ac
Pull Request resolved: #8194
  • Loading branch information
lin-hitonami authored and Taichi Gardener committed Jun 19, 2023
1 parent 1103ae8 commit 45c98dd
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def decl_and_create_variable(annotation, name, arg_features):
def transform_as_kernel():
# Treat return type
if node.returns is not None:
kernel_arguments.decl_ret(ctx.func.return_type, ctx.is_real_function)
kernel_arguments.decl_ret(ctx.func.return_type)
impl.get_runtime().compiling_callable.finalize_rets()

for i, arg in enumerate(args.args):
Expand Down
6 changes: 1 addition & 5 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,10 @@ def decl_rw_texture_arg(num_dimensions, buffer_format, lod, name):
)


def decl_ret(dtype, real_func=False):
def decl_ret(dtype):
if isinstance(dtype, StructType):
dtype = dtype.dtype
if isinstance(dtype, MatrixType):
if real_func:
for i in range(dtype.n * dtype.m):
decl_ret(dtype.dtype)
return
dtype = _ti_core.get_type_factory_instance().get_tensor_type([dtype.n, dtype.m], dtype.dtype)
else:
dtype = cook_dtype(dtype)
Expand Down
2 changes: 2 additions & 0 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ def func_call_rvalue(self, key, args):
return Expr(_ti_core.make_get_element_expr(func_call.ptr, (0,)))
if isinstance(self.return_type, StructType):
return self.return_type.from_taichi_object(func_call, (0,))
if isinstance(self.return_type, MatrixType):
return self.return_type.from_taichi_object(func_call, (0,))
raise TaichiTypeError(f"Unsupported return type: {self.return_type}")

def do_compile(self, key, args, arg_features):
Expand Down
13 changes: 13 additions & 0 deletions tests/python/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,19 @@ def foo() -> float:
assert foo() == pytest.approx(21)


@test_utils.test(arch=[ti.cpu, ti.cuda])
def test_real_func_matrix_return():
@ti.experimental.real_func
def mat_ret() -> ti.math.mat2:
return ti.math.mat2(1, 2, 3, 4)

@ti.kernel
def foo() -> ti.math.mat2:
return mat_ret()

assert (foo() == ti.math.mat2(1, 2, 3, 4)).all()


@test_utils.test(arch=[ti.cpu, ti.cuda])
def test_real_func_struct_ret():
s = ti.types.struct(a=ti.i16, b=ti.f64)
Expand Down

0 comments on commit 45c98dd

Please sign in to comment.