diff --git a/python/taichi/lang/__init__.py b/python/taichi/lang/__init__.py index d0749a4d9722c..128218630cb33 100644 --- a/python/taichi/lang/__init__.py +++ b/python/taichi/lang/__init__.py @@ -421,6 +421,7 @@ def __init__(self): self.gdb_trigger = False self.excepthook = False self.experimental_real_function = False + self.short_circuit_operators = False def prepare_sandbox(): @@ -591,6 +592,7 @@ def init(arch=None, env_spec.add('gdb_trigger') env_spec.add('excepthook') env_spec.add('experimental_real_function') + env_spec.add('short_circuit_operators') # compiler configurations (ti.cfg): for key in dir(ti.cfg): @@ -620,6 +622,8 @@ def init(arch=None, impl.get_runtime().print_preprocessed = spec_cfg.print_preprocessed impl.get_runtime().experimental_real_function = \ spec_cfg.experimental_real_function + impl.get_runtime().short_circuit_operators = \ + spec_cfg.short_circuit_operators ti.set_logging_level(spec_cfg.log_level.lower()) if spec_cfg.excepthook: # TODO(#1405): add a way to restore old excepthook diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 3baf7ce6bdc77..d649606a2b2d5 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -476,17 +476,68 @@ def build_UnaryOp(ctx, node): node.ptr = op(node.operand.ptr) return node + @staticmethod + def build_short_circuit_and(operands): + if len(operands) == 1: + return operands[0].ptr + + val = ti.expr_init(None) + lhs = operands[0].ptr + ti.begin_frontend_if(lhs) + + ti.core.begin_frontend_if_true() + rhs = ASTTransformer.build_short_circuit_and(operands[1:]) + val.assign(rhs) + ti.core.pop_scope() + + ti.core.begin_frontend_if_false() + val.assign(0) + ti.core.pop_scope() + + return val + + @staticmethod + def build_short_circuit_or(operands): + if len(operands) == 1: + return operands[0].ptr + + val = ti.expr_init(None) + lhs = operands[0].ptr + ti.begin_frontend_if(lhs) + + ti.core.begin_frontend_if_true() + val.assign(1) + ti.core.pop_scope() + + ti.core.begin_frontend_if_false() + rhs = ASTTransformer.build_short_circuit_or(operands[1:]) + val.assign(rhs) + ti.core.pop_scope() + + return val + + @staticmethod + def build_normal_bool_op(op): + def inner(operands): + result = op(operands[0].ptr, operands[1].ptr) + for i in range(2, len(operands)): + result = op(result, operands[i].ptr) + return result + + return inner + @staticmethod def build_BoolOp(ctx, node): node.values = build_stmts(ctx, node.values) - op = { - ast.And: ti.logical_and, - ast.Or: ti.logical_or, - }.get(type(node.op)) - result = op(node.values[0].ptr, node.values[1].ptr) - for i in range(2, len(node.values)): - result = op(result, node.values[i].ptr) - node.ptr = result + ops = { + ast.And: ASTTransformer.build_short_circuit_and, + ast.Or: ASTTransformer.build_short_circuit_or, + } if impl.get_runtime().short_circuit_operators else { + ast.And: ASTTransformer.build_normal_bool_op(ti.logical_and), + ast.Or: ASTTransformer.build_normal_bool_op(ti.logical_or), + } + op = ops.get(type(node.op)) + node.ptr = op(node.values) return node @staticmethod diff --git a/tests/python/test_short_circuit.py b/tests/python/test_short_circuit.py new file mode 100644 index 0000000000000..40e45998eb95b --- /dev/null +++ b/tests/python/test_short_circuit.py @@ -0,0 +1,49 @@ +import taichi as ti + + +@ti.test(debug=True, short_circuit_operators=True) +def test_and_shorted(): + a = ti.field(ti.i32, shape=10) + + @ti.func + def explode() -> ti.i32: + return a[-1] + + @ti.kernel + def func() -> ti.i32: + return False and explode() + + assert func() == 0 + + +@ti.test(debug=True, short_circuit_operators=True) +def test_and_not_shorted(): + @ti.kernel + def func() -> ti.i32: + return True and False + + assert func() == 0 + + +@ti.test(debug=True, short_circuit_operators=True) +def test_or_shorted(): + a = ti.field(ti.i32, shape=10) + + @ti.func + def explode() -> ti.i32: + return a[-1] + + @ti.kernel + def func() -> ti.i32: + return True or explode() + + assert func() == 1 + + +@ti.test(debug=True, short_circuit_operators=True) +def test_or_not_shorted(): + @ti.kernel + def func() -> ti.i32: + return False or True + + assert func() == 1