Skip to content

Commit

Permalink
[Lang] Allow implicit conversion of integer types in if conditions (#…
Browse files Browse the repository at this point in the history
…5763)

* [Lang] Make bool literals `i32` instead of `default_ip`

* [Lang] Implicitly convert other integer types to i32 in conditions

* [Lang] Add tests for implicit conversions for conditions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [Lang] Add tests i32 bool literals

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yi Xu <xy_xuyi@foxmail.com>
  • Loading branch information
3 people authored Aug 23, 2022
1 parent b1e85ec commit 1390bcc
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
5 changes: 5 additions & 0 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from taichi.lang.common_ops import TaichiOperations
from taichi.lang.exception import TaichiTypeError
from taichi.lang.util import is_taichi_class, to_numpy_type
from taichi.types import primitive_types
from taichi.types.primitive_types import integer_types, real_types


Expand Down Expand Up @@ -65,6 +66,10 @@ def _clamp_unsigned_to_range(npty, val):


def make_constant_expr(val, dtype):
if isinstance(val, bool):
constant_dtype = primitive_types.i32
return Expr(_ti_core.make_const_expr_int(constant_dtype, val))

if isinstance(val, (float, np.floating)):
constant_dtype = impl.get_runtime(
).default_fp if dtype is None else dtype
Expand Down
4 changes: 2 additions & 2 deletions taichi/transforms/frontend_type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ namespace lang {

class FrontendTypeCheck : public IRVisitor {
void check_cond_type(const Expr &cond, std::string stmt_name) {
if (!cond->ret_type->is_primitive(PrimitiveTypeID::i32))
if (!cond->ret_type->is<PrimitiveType>() || !is_integral(cond->ret_type))
throw TaichiTypeError(fmt::format(
"`{0}` conditions must be of type i32; found {1}. Consider using "
"`{0}` conditions must be an integer; found {1}. Consider using "
"`{0} x != 0` instead of `{0} x` for float values.",
stmt_name, cond->ret_type->to_string()));
}
Expand Down
24 changes: 24 additions & 0 deletions tests/python/test_bool_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,27 @@ def func() -> ti.i32:
return ti.static(5 and 2 and 0)

assert func() == 0


@test_utils.test(require=ti.extension.data64, default_ip=ti.i64)
def test_condition_type():
@ti.kernel
def func() -> int:
x = 0
result = 0
if x:
result = 1
else:
result = 2
return result

assert func() == 2


@test_utils.test(require=ti.extension.data64, default_ip=ti.i64)
def test_i32_bool():
@ti.kernel
def func() -> ti.i32:
return True

assert func() == 1

0 comments on commit 1390bcc

Please sign in to comment.