Skip to content

Commit

Permalink
[Lang] add option short_circuit_operators for shor-circuiting boole…
Browse files Browse the repository at this point in the history
…an ops (taichi-dev#3572)
  • Loading branch information
re-xyr committed Nov 28, 2021
1 parent 21d9048 commit ecc5e92
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 8 deletions.
4 changes: 4 additions & 0 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def __init__(self):
self.gdb_trigger = False
self.excepthook = False
self.experimental_real_function = False
self.short_circuit_operators = False


def prepare_sandbox():
Expand Down Expand Up @@ -591,6 +592,7 @@ def init(arch=None,
env_spec.add('gdb_trigger')
env_spec.add('excepthook')
env_spec.add('experimental_real_function')
env_spec.add('short_circuit_operators')

# compiler configurations (ti.cfg):
for key in dir(ti.cfg):
Expand Down Expand Up @@ -620,6 +622,8 @@ def init(arch=None,
impl.get_runtime().print_preprocessed = spec_cfg.print_preprocessed
impl.get_runtime().experimental_real_function = \
spec_cfg.experimental_real_function
impl.get_runtime().short_circuit_operators = \
spec_cfg.short_circuit_operators
ti.set_logging_level(spec_cfg.log_level.lower())
if spec_cfg.excepthook:
# TODO(#1405): add a way to restore old excepthook
Expand Down
67 changes: 59 additions & 8 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,17 +476,68 @@ def build_UnaryOp(ctx, node):
node.ptr = op(node.operand.ptr)
return node

@staticmethod
def build_short_circuit_and(operands):
if len(operands) == 1:
return operands[0].ptr

val = ti.expr_init(None)
lhs = operands[0].ptr
ti.begin_frontend_if(lhs)

ti.core.begin_frontend_if_true()
rhs = ASTTransformer.build_short_circuit_and(operands[1:])
val.assign(rhs)
ti.core.pop_scope()

ti.core.begin_frontend_if_false()
val.assign(0)
ti.core.pop_scope()

return val

@staticmethod
def build_short_circuit_or(operands):
if len(operands) == 1:
return operands[0].ptr

val = ti.expr_init(None)
lhs = operands[0].ptr
ti.begin_frontend_if(lhs)

ti.core.begin_frontend_if_true()
val.assign(1)
ti.core.pop_scope()

ti.core.begin_frontend_if_false()
rhs = ASTTransformer.build_short_circuit_or(operands[1:])
val.assign(rhs)
ti.core.pop_scope()

return val

@staticmethod
def build_normal_bool_op(op):
def inner(operands):
result = op(operands[0].ptr, operands[1].ptr)
for i in range(2, len(operands)):
result = op(result, operands[i].ptr)
return result

return inner

@staticmethod
def build_BoolOp(ctx, node):
node.values = build_stmts(ctx, node.values)
op = {
ast.And: ti.logical_and,
ast.Or: ti.logical_or,
}.get(type(node.op))
result = op(node.values[0].ptr, node.values[1].ptr)
for i in range(2, len(node.values)):
result = op(result, node.values[i].ptr)
node.ptr = result
ops = {
ast.And: ASTTransformer.build_short_circuit_and,
ast.Or: ASTTransformer.build_short_circuit_or,
} if impl.get_runtime().short_circuit_operators else {
ast.And: ASTTransformer.build_normal_bool_op(ti.logical_and),
ast.Or: ASTTransformer.build_normal_bool_op(ti.logical_or),
}
op = ops.get(type(node.op))
node.ptr = op(node.values)
return node

@staticmethod
Expand Down
49 changes: 49 additions & 0 deletions tests/python/test_short_circuit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import taichi as ti


@ti.test(debug=True, short_circuit_operators=True)
def test_and_shorted():
a = ti.field(ti.i32, shape=10)

@ti.func
def explode() -> ti.i32:
return a[-1]

@ti.kernel
def func() -> ti.i32:
return False and explode()

assert func() == 0


@ti.test(debug=True, short_circuit_operators=True)
def test_and_not_shorted():
@ti.kernel
def func() -> ti.i32:
return True and False

assert func() == 0


@ti.test(debug=True, short_circuit_operators=True)
def test_or_shorted():
a = ti.field(ti.i32, shape=10)

@ti.func
def explode() -> ti.i32:
return a[-1]

@ti.kernel
def func() -> ti.i32:
return True or explode()

assert func() == 1


@ti.test(debug=True, short_circuit_operators=True)
def test_or_not_shorted():
@ti.kernel
def func() -> ti.i32:
return False or True

assert func() == 1

0 comments on commit ecc5e92

Please sign in to comment.