diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 42288d138b282..4a4208f928360 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -21,9 +21,11 @@ from taichi.lang.kernel_arguments import KernelArgument from taichi.lang.matrix import Matrix, MatrixType from taichi.lang.shell import _shell_pop_print, oinspect -from taichi.lang.util import has_paddle, has_pytorch, to_taichi_type +from taichi.lang.util import (cook_dtype, has_paddle, has_pytorch, + to_taichi_type) from taichi.types import (ndarray_type, primitive_types, sparse_matrix_builder, template, texture_type) +from taichi.types.utils import is_signed from taichi import _logging @@ -661,10 +663,14 @@ def func__(*args): if not isinstance(v, int): raise TaichiRuntimeTypeError.get( i, needed.to_string(), provided) - launch_ctx.set_arg_int(actual_argument_slot, int(v)) + if is_signed(cook_dtype(needed)): + launch_ctx.set_arg_int(actual_argument_slot, int(v)) + else: + launch_ctx.set_arg_uint(actual_argument_slot, int(v)) elif isinstance(needed, sparse_matrix_builder): # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument - launch_ctx.set_arg_int(actual_argument_slot, v._get_addr()) + launch_ctx.set_arg_uint(actual_argument_slot, + v._get_addr()) elif isinstance(needed, ndarray_type.NdarrayType) and isinstance( v, taichi.lang._ndarray.Ndarray): @@ -743,8 +749,12 @@ def func__(*args): if not isinstance(val, int): raise TaichiRuntimeTypeError.get( i, needed.dtype.to_string(), type(val)) - launch_ctx.set_arg_int(actual_argument_slot, - int(val)) + if is_signed(needed.dtype): + launch_ctx.set_arg_int( + actual_argument_slot, int(val)) + else: + launch_ctx.set_arg_uint( + actual_argument_slot, int(val)) actual_argument_slot += 1 else: raise ValueError( diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index abedf47778423..fb18a55c8b4bd 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -191,7 +191,7 @@ void Kernel::LaunchContextBuilder::set_arg_int(int arg_id, int64 d) { "not allowed."); ActionRecorder::get_instance().record( - "set_kernel_arg_int64", + "set_kernel_arg_integer", {ActionArg("kernel_name", kernel_->name), ActionArg("arg_id", arg_id), ActionArg("val", d)}); @@ -218,6 +218,10 @@ void Kernel::LaunchContextBuilder::set_arg_int(int arg_id, int64 d) { } } +void Kernel::LaunchContextBuilder::set_arg_uint(int arg_id, uint64 d) { + set_arg_int(arg_id, d); +} + void Kernel::LaunchContextBuilder::set_extra_arg_int(int i, int j, int32 d) { ctx_->extra_args[i][j] = d; } diff --git a/taichi/program/kernel.h b/taichi/program/kernel.h index 4c7bca71030a6..7b0bf7d9cded8 100644 --- a/taichi/program/kernel.h +++ b/taichi/program/kernel.h @@ -35,7 +35,9 @@ class TI_DLL_EXPORT Kernel : public Callable { void set_arg_float(int arg_id, float64 d); + // Created signed and unsigned version for argument range check of pybind void set_arg_int(int arg_id, int64 d); + void set_arg_uint(int arg_id, uint64 d); void set_extra_arg_int(int i, int j, int32 d); diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 5b4014c9b09ae..e3a3fddac9a12 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -715,6 +715,7 @@ void export_lang(py::module &m) { py::class_(m, "KernelLaunchContext") .def("set_arg_int", &Kernel::LaunchContextBuilder::set_arg_int) + .def("set_arg_uint", &Kernel::LaunchContextBuilder::set_arg_uint) .def("set_arg_float", &Kernel::LaunchContextBuilder::set_arg_float) .def("set_arg_external_array_with_shape", &Kernel::LaunchContextBuilder::set_arg_external_array_with_shape) diff --git a/tests/python/test_kernel_arg_errors.py b/tests/python/test_kernel_arg_errors.py index ba3759893bc56..23a495a371d04 100644 --- a/tests/python/test_kernel_arg_errors.py +++ b/tests/python/test_kernel_arg_errors.py @@ -1,3 +1,5 @@ +import platform + import pytest import taichi as ti @@ -18,6 +20,19 @@ def foo(a: ti.i32): foo(1.2) +@test_utils.test(exclude=[ti.metal]) +def test_pass_u64(): + if ti.lang.impl.current_cfg().arch == ti.vulkan and platform.system( + ) == "Darwin": + return + + @ti.kernel + def foo(a: ti.u64): + pass + + foo(2**64 - 1) + + @test_utils.test() def test_argument_redefinition(): @ti.kernel