From 42a9cae3ecd3ab364a4daf40a911470cc0f6fd0d Mon Sep 17 00:00:00 2001 From: Proton Date: Mon, 4 Jul 2022 18:43:05 +0800 Subject: [PATCH 1/2] [bug] Accept numpy integers in ndrange (#5245) --- python/taichi/lang/_ndrange.py | 3 ++- tests/python/test_ndrange.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/taichi/lang/_ndrange.py b/python/taichi/lang/_ndrange.py index bde0e7aab1f04..51ce24288def1 100644 --- a/python/taichi/lang/_ndrange.py +++ b/python/taichi/lang/_ndrange.py @@ -1,5 +1,6 @@ import collections.abc +import numpy as np from taichi.lang import ops from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError from taichi.lang.expr import Expr @@ -20,7 +21,7 @@ def __init__(self, *args): args[i] = (args[i][0], ops.max(args[i][0], args[i][1])) for arg in args: for bound in arg: - if not isinstance(bound, int) and not ( + if not isinstance(bound, (int, np.integer)) and not ( isinstance(bound, Expr) and is_integral(bound.ptr.get_ret_type())): raise TaichiTypeError( diff --git a/tests/python/test_ndrange.py b/tests/python/test_ndrange.py index fda679df820a5..db6fbe8559cce 100644 --- a/tests/python/test_ndrange.py +++ b/tests/python/test_ndrange.py @@ -278,6 +278,18 @@ def example(): example() +@test_utils.test() +def test_ndrange_should_accept_numpy_integer(): + a, b = np.int64(0), np.int32(10) + + @ti.kernel + def example(): + for i in ti.ndrange((a, b)): + pass + + example() + + @test_utils.test() def test_static_ndrange_non_integer_arguments(): @ti.kernel From 6e2012765425d65e64fb458065592e70a2e0c9df Mon Sep 17 00:00:00 2001 From: Proton Date: Mon, 4 Jul 2022 18:51:36 +0800 Subject: [PATCH 2/2] add static ndrange test --- tests/python/test_ndrange.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/python/test_ndrange.py b/tests/python/test_ndrange.py index db6fbe8559cce..abf0b5618593d 100644 --- a/tests/python/test_ndrange.py +++ b/tests/python/test_ndrange.py @@ -303,3 +303,15 @@ def example(): r"Every argument of ndrange should be an integer scalar or a tuple/list of \(int, int\)" ): example() + + +@test_utils.test() +def test_static_ndrange_should_accept_numpy_integer(): + a, b = np.int64(0), np.int32(10) + + @ti.kernel + def example(): + for i in ti.static(ti.ndrange((a, b))): + pass + + example()