diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 89b6a549bd75e..bfdfa23985a1b 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -12,6 +12,7 @@ import taichi.lang from taichi._lib import core as _ti_core from taichi.lang import impl, ops, runtime_ops +from taichi.lang.any_array import AnyArray from taichi.lang._wrap_inspect import getsourcefile, getsourcelines from taichi.lang.argpack import ArgPackType, ArgPack from taichi.lang.ast import ( @@ -226,12 +227,12 @@ def __call__(self, *args, **kwargs): if self.is_real_function: if impl.get_runtime().current_kernel.autodiff_mode != AutodiffMode.NONE: raise TaichiSyntaxError("Real function in gradient kernels unsupported.") - instance_id, _ = self.mapper.lookup(args) + instance_id, arg_features = self.mapper.lookup(args) key = _ti_core.FunctionKey(self.func.__name__, self.func_id, instance_id) if self.compiled is None: self.compiled = {} if key.instance_id not in self.compiled: - self.do_compile(key=key, args=args) + self.do_compile(key=key, args=args, arg_features=arg_features) return self.func_call_rvalue(key=key, args=args) tree, ctx = _get_tree_and_ctx( self, @@ -257,6 +258,12 @@ def func_call_rvalue(self, key, args): non_template_args.append(ops.cast(args[i], anno)) elif isinstance(anno, primitive_types.RefType): non_template_args.append(_ti_core.make_reference(args[i].ptr)) + elif isinstance(anno, ndarray_type.NdarrayType): + if not isinstance(args[i], AnyArray): + raise TaichiTypeError( + f"Expected ndarray in the kernel argument for argument {kernel_arg.name}, got {args[i]}" + ) + non_template_args.append(args[i].ptr) else: non_template_args.append(args[i]) non_template_args = impl.make_expr_group(non_template_args, real_func_arg=True) @@ -274,8 +281,10 @@ def func_call_rvalue(self, key, args): return self.return_type.from_taichi_object(func_call, (0,)) raise TaichiTypeError(f"Unsupported return type: {self.return_type}") - def do_compile(self, key, args): - tree, ctx = _get_tree_and_ctx(self, is_kernel=False, args=args, is_real_function=self.is_real_function) + def do_compile(self, key, args, arg_features): + tree, ctx = _get_tree_and_ctx( + self, is_kernel=False, args=args, arg_features=arg_features, is_real_function=self.is_real_function + ) fn = impl.get_runtime().prog.create_function(key) def func_body(): @@ -403,6 +412,10 @@ def extract_arg(arg, anno, arg_name): anno.check_matched(arg.get_type(), arg_name) needs_grad = (arg.grad is not None) if anno.needs_grad is None else anno.needs_grad return arg.element_type, len(arg.shape), needs_grad, anno.boundary + if isinstance(arg, AnyArray): + ty = arg.get_type() + anno.check_matched(arg.get_type(), arg_name) + return ty.element_type, len(arg.shape), ty.needs_grad, anno.boundary # external arrays shape = getattr(arg, "shape", None) if shape is None: diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index f13307e5b506d..8d773321c820a 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -717,6 +717,7 @@ void export_lang(py::module &m) { py::class_(m, "Function") .def("insert_scalar_param", &Function::insert_scalar_param) .def("insert_arr_param", &Function::insert_arr_param) + .def("insert_ndarray_param", &Function::insert_ndarray_param) .def("insert_texture_param", &Function::insert_texture_param) .def("insert_pointer_param", &Function::insert_pointer_param) .def("insert_rw_texture_param", &Function::insert_rw_texture_param) diff --git a/tests/python/test_ndarray.py b/tests/python/test_ndarray.py index 55bb980256cb8..3182613d8499a 100644 --- a/tests/python/test_ndarray.py +++ b/tests/python/test_ndarray.py @@ -1005,15 +1005,48 @@ def test(x: ti.types.ndarray(dtype=ti.types.vector())): @test_utils.test(arch=supported_archs_taichi_ndarray) def test_pass_ndarray_to_func(): @ti.func - def bar(weight: ti.types.ndarray(ti.f32, ndim=3)): - pass + def bar(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32: + return weight[1, 1, 1] + + @ti.kernel + def foo(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32: + return bar(weight) + + weight = ti.ndarray(dtype=ti.f32, shape=(2, 2, 2)) + weight.fill(42.0) + assert foo(weight) == 42.0 + + +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_pass_ndarray_to_real_func(): + @ti.experimental.real_func + def bar(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32: + return weight[1, 1, 1] @ti.kernel - def foo(weight: ti.types.ndarray(ti.f32, ndim=3)): - bar(weight) + def foo(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32: + return bar(weight) weight = ti.ndarray(dtype=ti.f32, shape=(2, 2, 2)) - foo(weight) + weight.fill(42.0) + assert foo(weight) == 42.0 + + +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_pass_ndarray_outside_kernel_to_real_func(): + weight = ti.ndarray(dtype=ti.f32, shape=(2, 2, 2)) + + @ti.experimental.real_func + def bar(weight: ti.types.ndarray(ti.f32, ndim=3)) -> ti.f32: + return weight[1, 1, 1] + + @ti.kernel + def foo() -> ti.f32: + return bar(weight) + + weight.fill(42.0) + with pytest.raises(ti.TaichiTypeError, match=r"Expected ndarray in the kernel argument for argument weight"): + foo() @test_utils.test(arch=supported_archs_taichi_ndarray)