Skip to content

Commit

Permalink
[Lang] [bug] Fix error on ndarray type check
Browse files Browse the repository at this point in the history
ghstack-source-id: 873b5498564cf01d0fb03080808ef880ed8183f0
Pull Request resolved: #8230
  • Loading branch information
lin-hitonami authored and Taichi Gardener committed Jun 29, 2023
1 parent a1c222d commit 9945f44
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/taichi/types/ndarray_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from taichi.lang.enums import Layout, to_boundary_enum
from taichi.types.compound_types import CompoundType, matrix, vector
from taichi.lang import util


class NdarrayTypeMetadata:
Expand Down Expand Up @@ -94,7 +95,7 @@ def check_matched(self, ndarray_type: NdarrayTypeMetadata, arg_name: str):
else:
if self.dtype is not None:
# Check dtype match for scalar.
if not self.dtype == ndarray_type.element_type:
if not util.cook_dtype(self.dtype) == ndarray_type.element_type:
raise TypeError(
f"Expect element type {self.dtype} for argument {arg_name}, but get {ndarray_type.element_type}"
)
Expand Down
11 changes: 11 additions & 0 deletions tests/python/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,3 +1103,14 @@ def test(ao: ti.types.ndarray(dtype=ti.f32, ndim=2, boundary="clamp")):
ao = ti.ndarray(ti.f32, shape=(height, width))
test(ao)
assert (ao.to_numpy() == np.zeros((height, width))).all()


@test_utils.test(arch=supported_archs_taichi_ndarray)
def test_ndarray_arg_builtin_float_type():
@ti.kernel
def foo(x: ti.types.ndarray(float, ndim=0)) -> ti.f32:
return x[None]

x = ti.ndarray(ti.f32, shape=())
x[None] = 42
assert foo(x) == 42

0 comments on commit 9945f44

Please sign in to comment.