diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 9a5b7d1e062dbf..c06985194d6692 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -899,9 +899,10 @@ def any(self): True """ ret = False - for i in range(0, len(self.entries)): + # we must use range+len, because entries doesn't have __iter__ + for i in range(0, len(self.entries)): # pylint: disable=consider-using-enumerate ret = ret | ops_mod.cmp_ne(self.entries[i], 0) - return ops_mod.cmp_ne(ret, 0) & 1 + return ret & True def all(self): """Test whether all element not equal zero. @@ -916,9 +917,9 @@ def all(self): False """ ret = True - for i in range(0, len(self.entries)): + for i in range(0, len(self.entries)): # pylint: disable=consider-using-enumerate ret = ret & ops_mod.cmp_ne(self.entries[i], 0) - return ops_mod.cmp_ne(ret, 0) & 1 + return ret @taichi_scope def fill(self, val): diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index 70a187a5bacf4a..25564bad52f6b0 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -917,7 +917,7 @@ def cmp_lt(a, b): Union[:class:`~taichi.lang.expr.Expr`, bool]: True if LHS is strictly smaller than RHS, False otherwise """ - return _binary_operation(_ti_core.expr_cmp_lt,lambda a,b: int(a < b),a, + return _binary_operation(_ti_core.expr_cmp_lt, lambda a, b: int(a < b), a, b) @@ -933,8 +933,8 @@ def cmp_le(a, b): Union[:class:`~taichi.lang.expr.Expr`, bool]: True if LHS is smaller than or equal to RHS, False otherwise """ - return _binary_operation(_ti_core.expr_cmp_le,lambda a,b: int(a <= b), - a,b) + return _binary_operation(_ti_core.expr_cmp_le, lambda a, b: int(a <= b), a, + b) @binary @@ -949,7 +949,7 @@ def cmp_gt(a, b): Union[:class:`~taichi.lang.expr.Expr`, bool]: True if LHS is strictly larger than RHS, False otherwise """ - return _binary_operation(_ti_core.expr_cmp_gt,lambda a,b: int(a > b),a, + return _binary_operation(_ti_core.expr_cmp_gt, lambda a, b: int(a > b), a, b) @@ -965,8 +965,8 @@ def cmp_ge(a, b): bool: True if LHS is greater than or equal to RHS, False otherwise """ - return _binary_operation(_ti_core.expr_cmp_ge,lambda a,b: int(a >= b), - a,b) + return _binary_operation(_ti_core.expr_cmp_ge, lambda a, b: int(a >= b), a, + b) @binary @@ -981,8 +981,8 @@ def cmp_eq(a, b): Union[:class:`~taichi.lang.expr.Expr`, bool]: True if LHS is equal to RHS, False otherwise. """ - return _binary_operation(_ti_core.expr_cmp_eq,lambda a,b: int(a == b), - a,b) + return _binary_operation(_ti_core.expr_cmp_eq, lambda a, b: int(a == b), a, + b) @binary @@ -997,8 +997,8 @@ def cmp_ne(a, b): Union[:class:`~taichi.lang.expr.Expr`, bool]: True if LHS is not equal to RHS, False otherwise """ - return _binary_operation(_ti_core.expr_cmp_ne,lambda a,b: int(a != b), - a,b) + return _binary_operation(_ti_core.expr_cmp_ne, lambda a, b: int(a != b), a, + b) @binary diff --git a/tests/python/test_compare.py b/tests/python/test_compare.py index fe7f4049c3efc2..9607d4c7a687bd 100644 --- a/tests/python/test_compare.py +++ b/tests/python/test_compare.py @@ -219,5 +219,5 @@ def foo(): @test_utils.test() def test_python_scope_compare(): - v = ti.math.vec3(0,1,2) + v = ti.math.vec3(0, 1, 2) assert (v < 1)[0] == 1