Skip to content

Commit

Permalink
[lang] [bug] Fix setting integer arguments within u64 range but great…
Browse files Browse the repository at this point in the history
…er than i64 range (#6267)

Issue: fixes #6264 

### Brief Summary

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lin-hitonami and pre-commit-ci[bot] authored Oct 11, 2022
1 parent d4399ce commit f97e48c
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 6 deletions.
20 changes: 15 additions & 5 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
from taichi.lang.kernel_arguments import KernelArgument
from taichi.lang.matrix import Matrix, MatrixType
from taichi.lang.shell import _shell_pop_print, oinspect
from taichi.lang.util import has_paddle, has_pytorch, to_taichi_type
from taichi.lang.util import (cook_dtype, has_paddle, has_pytorch,
to_taichi_type)
from taichi.types import (ndarray_type, primitive_types, sparse_matrix_builder,
template, texture_type)
from taichi.types.utils import is_signed

from taichi import _logging

Expand Down Expand Up @@ -661,10 +663,14 @@ def func__(*args):
if not isinstance(v, int):
raise TaichiRuntimeTypeError.get(
i, needed.to_string(), provided)
launch_ctx.set_arg_int(actual_argument_slot, int(v))
if is_signed(cook_dtype(needed)):
launch_ctx.set_arg_int(actual_argument_slot, int(v))
else:
launch_ctx.set_arg_uint(actual_argument_slot, int(v))
elif isinstance(needed, sparse_matrix_builder):
# Pass only the base pointer of the ti.types.sparse_matrix_builder() argument
launch_ctx.set_arg_int(actual_argument_slot, v._get_addr())
launch_ctx.set_arg_uint(actual_argument_slot,
v._get_addr())
elif isinstance(needed,
ndarray_type.NdarrayType) and isinstance(
v, taichi.lang._ndarray.Ndarray):
Expand Down Expand Up @@ -743,8 +749,12 @@ def func__(*args):
if not isinstance(val, int):
raise TaichiRuntimeTypeError.get(
i, needed.dtype.to_string(), type(val))
launch_ctx.set_arg_int(actual_argument_slot,
int(val))
if is_signed(needed.dtype):
launch_ctx.set_arg_int(
actual_argument_slot, int(val))
else:
launch_ctx.set_arg_uint(
actual_argument_slot, int(val))
actual_argument_slot += 1
else:
raise ValueError(
Expand Down
6 changes: 5 additions & 1 deletion taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ void Kernel::LaunchContextBuilder::set_arg_int(int arg_id, int64 d) {
"not allowed.");

ActionRecorder::get_instance().record(
"set_kernel_arg_int64",
"set_kernel_arg_integer",
{ActionArg("kernel_name", kernel_->name), ActionArg("arg_id", arg_id),
ActionArg("val", d)});

Expand All @@ -218,6 +218,10 @@ void Kernel::LaunchContextBuilder::set_arg_int(int arg_id, int64 d) {
}
}

void Kernel::LaunchContextBuilder::set_arg_uint(int arg_id, uint64 d) {
set_arg_int(arg_id, d);
}

void Kernel::LaunchContextBuilder::set_extra_arg_int(int i, int j, int32 d) {
ctx_->extra_args[i][j] = d;
}
Expand Down
2 changes: 2 additions & 0 deletions taichi/program/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ class TI_DLL_EXPORT Kernel : public Callable {

void set_arg_float(int arg_id, float64 d);

// Created signed and unsigned version for argument range check of pybind
void set_arg_int(int arg_id, int64 d);
void set_arg_uint(int arg_id, uint64 d);

void set_extra_arg_int(int i, int j, int32 d);

Expand Down
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,7 @@ void export_lang(py::module &m) {

py::class_<Kernel::LaunchContextBuilder>(m, "KernelLaunchContext")
.def("set_arg_int", &Kernel::LaunchContextBuilder::set_arg_int)
.def("set_arg_uint", &Kernel::LaunchContextBuilder::set_arg_uint)
.def("set_arg_float", &Kernel::LaunchContextBuilder::set_arg_float)
.def("set_arg_external_array_with_shape",
&Kernel::LaunchContextBuilder::set_arg_external_array_with_shape)
Expand Down
15 changes: 15 additions & 0 deletions tests/python/test_kernel_arg_errors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import platform

import pytest

import taichi as ti
Expand All @@ -18,6 +20,19 @@ def foo(a: ti.i32):
foo(1.2)


@test_utils.test(exclude=[ti.metal])
def test_pass_u64():
if ti.lang.impl.current_cfg().arch == ti.vulkan and platform.system(
) == "Darwin":
return

@ti.kernel
def foo(a: ti.u64):
pass

foo(2**64 - 1)


@test_utils.test()
def test_argument_redefinition():
@ti.kernel
Expand Down

0 comments on commit f97e48c

Please sign in to comment.