Skip to content

Commit

Permalink
[error] Throw proper error message when an Ndarray is passed in via t…
Browse files Browse the repository at this point in the history
…i.template (#5457)

ti.types.ndarray itself is a template type so let's throw a better error
message when users misuse it.
  • Loading branch information
ailzhang authored Jul 19, 2022
1 parent 2e1675e commit 233123d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,10 @@ def extract_arg(arg, anno):
return tuple(
TaichiCallableTemplateMapper.extract_arg(item, anno)
for item in arg)
if isinstance(arg, taichi.lang._ndarray.Ndarray):
raise TaichiRuntimeTypeError(
'Ndarray shouldn\'t be passed in via `ti.template()`, please annotate your kernel using `ti.types.ndarray(...)` instead'
)
return arg
if isinstance(anno, texture_type.TextureType):
return '#'
Expand Down
14 changes: 14 additions & 0 deletions tests/python/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,3 +660,17 @@ def func(a: ti.types.ndarray()):
for k in range(2):
for p in range(2):
assert a2[i, j][k, p] == k * k


@test_utils.test(arch=supported_archs_taichi_ndarray)
def test_ndarray_as_template():
@ti.kernel
def func(arr_src: ti.template(), arr_dst: ti.template()):
for i, j in ti.ndrange(*arr_src.shape):
arr_dst[i, j] = arr_src[i, j]

arr_0 = ti.ndarray(ti.f32, shape=(5, 10))
arr_1 = ti.ndarray(ti.f32, shape=(5, 10))
with pytest.raises(ti.TaichiRuntimeTypeError,
match=r"Ndarray shouldn't be passed in via"):
func(arr_0, arr_1)

0 comments on commit 233123d

Please sign in to comment.