diff --git a/python/taichi/lang/_ndrange.py b/python/taichi/lang/_ndrange.py index bde0e7aab1f04..51ce24288def1 100644 --- a/python/taichi/lang/_ndrange.py +++ b/python/taichi/lang/_ndrange.py @@ -1,5 +1,6 @@ import collections.abc +import numpy as np from taichi.lang import ops from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError from taichi.lang.expr import Expr @@ -20,7 +21,7 @@ def __init__(self, *args): args[i] = (args[i][0], ops.max(args[i][0], args[i][1])) for arg in args: for bound in arg: - if not isinstance(bound, int) and not ( + if not isinstance(bound, (int, np.integer)) and not ( isinstance(bound, Expr) and is_integral(bound.ptr.get_ret_type())): raise TaichiTypeError( diff --git a/tests/python/test_ndrange.py b/tests/python/test_ndrange.py index fda679df820a5..abf0b5618593d 100644 --- a/tests/python/test_ndrange.py +++ b/tests/python/test_ndrange.py @@ -278,6 +278,18 @@ def example(): example() +@test_utils.test() +def test_ndrange_should_accept_numpy_integer(): + a, b = np.int64(0), np.int32(10) + + @ti.kernel + def example(): + for i in ti.ndrange((a, b)): + pass + + example() + + @test_utils.test() def test_static_ndrange_non_integer_arguments(): @ti.kernel @@ -291,3 +303,15 @@ def example(): r"Every argument of ndrange should be an integer scalar or a tuple/list of \(int, int\)" ): example() + + +@test_utils.test() +def test_static_ndrange_should_accept_numpy_integer(): + a, b = np.int64(0), np.int32(10) + + @ti.kernel + def example(): + for i in ti.static(ti.ndrange((a, b))): + pass + + example()