Skip to content

Commit

Permalink
[Lang] Support ndarray argument for real function
Browse files Browse the repository at this point in the history
ghstack-source-id: e8d7bbed74d451affa1ab3db3864e0dc51417edd
Pull Request resolved: #8188
  • Loading branch information
lin-hitonami committed Jun 16, 2023
1 parent bf26b46 commit 0ebf69d
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 9 deletions.
21 changes: 17 additions & 4 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import taichi.lang
from taichi._lib import core as _ti_core
from taichi.lang import impl, ops, runtime_ops
from taichi.lang.any_array import AnyArray
from taichi.lang._wrap_inspect import getsourcefile, getsourcelines
from taichi.lang.argpack import ArgPackType, ArgPack
from taichi.lang.ast import (
Expand Down Expand Up @@ -226,12 +227,12 @@ def __call__(self, *args, **kwargs):
if self.is_real_function:
if impl.get_runtime().current_kernel.autodiff_mode != AutodiffMode.NONE:
raise TaichiSyntaxError("Real function in gradient kernels unsupported.")
instance_id, _ = self.mapper.lookup(args)
instance_id, arg_features = self.mapper.lookup(args)
key = _ti_core.FunctionKey(self.func.__name__, self.func_id, instance_id)
if self.compiled is None:
self.compiled = {}
if key.instance_id not in self.compiled:
self.do_compile(key=key, args=args)
self.do_compile(key=key, args=args, arg_features=arg_features)
return self.func_call_rvalue(key=key, args=args)
tree, ctx = _get_tree_and_ctx(
self,
Expand All @@ -257,6 +258,12 @@ def func_call_rvalue(self, key, args):
non_template_args.append(ops.cast(args[i], anno))
elif isinstance(anno, primitive_types.RefType):
non_template_args.append(_ti_core.make_reference(args[i].ptr))
elif isinstance(anno, ndarray_type.NdarrayType):
if not isinstance(args[i], AnyArray):
raise TaichiTypeError(
f"Expected ndarray in the kernel argument for argument {kernel_arg.name}, got {args[i]}"
)
non_template_args.append(args[i].ptr)
else:
non_template_args.append(args[i])
non_template_args = impl.make_expr_group(non_template_args, real_func_arg=True)
Expand All @@ -274,8 +281,10 @@ def func_call_rvalue(self, key, args):
return self.return_type.from_taichi_object(func_call, (0,))
raise TaichiTypeError(f"Unsupported return type: {self.return_type}")

def do_compile(self, key, args):
tree, ctx = _get_tree_and_ctx(self, is_kernel=False, args=args, is_real_function=self.is_real_function)
def do_compile(self, key, args, arg_features):
tree, ctx = _get_tree_and_ctx(
self, is_kernel=False, args=args, arg_features=arg_features, is_real_function=self.is_real_function
)
fn = impl.get_runtime().prog.create_function(key)

def func_body():
Expand Down Expand Up @@ -403,6 +412,10 @@ def extract_arg(arg, anno, arg_name):
anno.check_matched(arg.get_type(), arg_name)
needs_grad = (arg.grad is not None) if anno.needs_grad is None else anno.needs_grad
return arg.element_type, len(arg.shape), needs_grad, anno.boundary
if isinstance(arg, AnyArray):
ty = arg.get_type()
anno.check_matched(arg.get_type(), arg_name)
return ty.element_type, len(arg.shape), ty.needs_grad, anno.boundary
# external arrays
shape = getattr(arg, "shape", None)
if shape is None:
Expand Down
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,7 @@ void export_lang(py::module &m) {
py::class_<Function>(m, "Function")
.def("insert_scalar_param", &Function::insert_scalar_param)
.def("insert_arr_param", &Function::insert_arr_param)
.def("insert_ndarray_param", &Function::insert_ndarray_param)
.def("insert_texture_param", &Function::insert_texture_param)
.def("insert_pointer_param", &Function::insert_pointer_param)
.def("insert_rw_texture_param", &Function::insert_rw_texture_param)
Expand Down
43 changes: 38 additions & 5 deletions tests/python/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,15 +1005,48 @@ def test(x: ti.types.ndarray(dtype=ti.types.vector())):
@test_utils.test(arch=supported_archs_taichi_ndarray)
def test_pass_ndarray_to_func():
@ti.func
def bar(weight: ti.types.ndarray(ti.f32, ndim=3)):
pass
def bar(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32:
return weight[1, 1, 1]

@ti.kernel
def foo(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32:
return bar(weight)

weight = ti.ndarray(dtype=ti.f32, shape=(2, 2, 2))
weight.fill(42.0)
assert foo(weight) == 42.0


@test_utils.test(arch=[ti.cpu, ti.cuda])
def test_pass_ndarray_to_real_func():
@ti.experimental.real_func
def bar(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32:
return weight[1, 1, 1]

@ti.kernel
def foo(weight: ti.types.ndarray(ti.f32, ndim=3)):
bar(weight)
def foo(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32:
return bar(weight)

weight = ti.ndarray(dtype=ti.f32, shape=(2, 2, 2))
foo(weight)
weight.fill(42.0)
assert foo(weight) == 42.0


@test_utils.test(arch=[ti.cpu, ti.cuda])
def test_pass_ndarray_outside_kernel_to_real_func():
weight = ti.ndarray(dtype=ti.f32, shape=(2, 2, 2))

@ti.experimental.real_func
def bar(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32:
return weight[1, 1, 1]

@ti.kernel
def foo() -> ti.f32:
return bar(weight)

weight.fill(42.0)
with pytest.raises(ti.TaichiTypeError, match=r"Expected ndarray in the kernel argument for argument weight"):
foo()


@test_utils.test(arch=supported_archs_taichi_ndarray)
Expand Down

0 comments on commit 0ebf69d

Please sign in to comment.