Skip to content

Commit

Permalink
[Lang] Make python scope comparison return 1 instead of -1
Browse files Browse the repository at this point in the history
  • Loading branch information
re-xyr committed Aug 22, 2022
1 parent 0fe8982 commit 5085962
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 18 deletions.
16 changes: 8 additions & 8 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,10 +898,10 @@ def any(self):
>>> v.any()
True
"""
ret = ops_mod.cmp_ne(self.entries[0], 0)
for i in range(1, len(self.entries)):
ret = ret + ops_mod.cmp_ne(self.entries[i], 0)
return -ops_mod.cmp_lt(ret, 0)
ret = False
for i in range(0, len(self.entries)):
ret = ret | ops_mod.cmp_ne(self.entries[i], 0)
return ops_mod.cmp_ne(ret, 0) & 1

def all(self):
"""Test whether all element not equal zero.
Expand All @@ -915,10 +915,10 @@ def all(self):
>>> v.all()
False
"""
ret = ops_mod.cmp_ne(self.entries[0], 0)
for i in range(1, len(self.entries)):
ret = ret + ops_mod.cmp_ne(self.entries[i], 0)
return -ops_mod.cmp_eq(ret, -len(self.entries))
ret = True
for i in range(0, len(self.entries)):
ret = ret & ops_mod.cmp_ne(self.entries[i], 0)
return ops_mod.cmp_ne(ret, 0) & 1

@taichi_scope
def fill(self, val):
Expand Down
20 changes: 10 additions & 10 deletions python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,7 @@ def cmp_lt(a, b):
Union[:class:`~taichi.lang.expr.Expr`, bool]: True if LHS is strictly smaller than RHS, False otherwise
"""
return _binary_operation(_ti_core.expr_cmp_lt, lambda a, b: -int(a < b), a,
return _binary_operation(_ti_core.expr_cmp_lt,lambda a,b: int(a < b),a,
b)


Expand All @@ -933,8 +933,8 @@ def cmp_le(a, b):
Union[:class:`~taichi.lang.expr.Expr`, bool]: True if LHS is smaller than or equal to RHS, False otherwise
"""
return _binary_operation(_ti_core.expr_cmp_le, lambda a, b: -int(a <= b),
a, b)
return _binary_operation(_ti_core.expr_cmp_le,lambda a,b: int(a <= b),
a,b)


@binary
Expand All @@ -949,7 +949,7 @@ def cmp_gt(a, b):
Union[:class:`~taichi.lang.expr.Expr`, bool]: True if LHS is strictly larger than RHS, False otherwise
"""
return _binary_operation(_ti_core.expr_cmp_gt, lambda a, b: -int(a > b), a,
return _binary_operation(_ti_core.expr_cmp_gt,lambda a,b: int(a > b),a,
b)


Expand All @@ -965,8 +965,8 @@ def cmp_ge(a, b):
bool: True if LHS is greater than or equal to RHS, False otherwise
"""
return _binary_operation(_ti_core.expr_cmp_ge, lambda a, b: -int(a >= b),
a, b)
return _binary_operation(_ti_core.expr_cmp_ge,lambda a,b: int(a >= b),
a,b)


@binary
Expand All @@ -981,8 +981,8 @@ def cmp_eq(a, b):
Union[:class:`~taichi.lang.expr.Expr`, bool]: True if LHS is equal to RHS, False otherwise.
"""
return _binary_operation(_ti_core.expr_cmp_eq, lambda a, b: -int(a == b),
a, b)
return _binary_operation(_ti_core.expr_cmp_eq,lambda a,b: int(a == b),
a,b)


@binary
Expand All @@ -997,8 +997,8 @@ def cmp_ne(a, b):
Union[:class:`~taichi.lang.expr.Expr`, bool]: True if LHS is not equal to RHS, False otherwise
"""
return _binary_operation(_ti_core.expr_cmp_ne, lambda a, b: -int(a != b),
a, b)
return _binary_operation(_ti_core.expr_cmp_ne,lambda a,b: int(a != b),
a,b)


@binary
Expand Down
6 changes: 6 additions & 0 deletions tests/python/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,9 @@ def foo():
i -= 1

foo()


@test_utils.test()
def test_python_scope_compare():
v = ti.math.vec3(0,1,2)
assert (v < 1)[0] == 1

0 comments on commit 5085962

Please sign in to comment.