Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Add option short_circuit_operators for short-circuiting boolean ops #3632

Merged
merged 1 commit into from
Nov 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def __init__(self):
self.gdb_trigger = False
self.excepthook = False
self.experimental_real_function = False
self.short_circuit_operators = False


def prepare_sandbox():
Expand Down Expand Up @@ -591,6 +592,7 @@ def init(arch=None,
env_spec.add('gdb_trigger')
env_spec.add('excepthook')
env_spec.add('experimental_real_function')
env_spec.add('short_circuit_operators')

# compiler configurations (ti.cfg):
for key in dir(ti.cfg):
Expand Down Expand Up @@ -620,6 +622,8 @@ def init(arch=None,
impl.get_runtime().print_preprocessed = spec_cfg.print_preprocessed
impl.get_runtime().experimental_real_function = \
spec_cfg.experimental_real_function
impl.get_runtime().short_circuit_operators = \
spec_cfg.short_circuit_operators
ti.set_logging_level(spec_cfg.log_level.lower())
if spec_cfg.excepthook:
# TODO(#1405): add a way to restore old excepthook
Expand Down
67 changes: 59 additions & 8 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,17 +476,68 @@ def build_UnaryOp(ctx, node):
node.ptr = op(node.operand.ptr)
return node

@staticmethod
def build_short_circuit_and(operands):
if len(operands) == 1:
return operands[0].ptr
strongoier marked this conversation as resolved.
Show resolved Hide resolved

val = ti.expr_init(None)
lhs = operands[0].ptr
ti.begin_frontend_if(lhs)

ti.core.begin_frontend_if_true()
rhs = ASTTransformer.build_short_circuit_and(operands[1:])
val.assign(rhs)
ti.core.pop_scope()

ti.core.begin_frontend_if_false()
val.assign(0)
ti.core.pop_scope()

return val

@staticmethod
def build_short_circuit_or(operands):
if len(operands) == 1:
return operands[0].ptr

val = ti.expr_init(None)
lhs = operands[0].ptr
ti.begin_frontend_if(lhs)

ti.core.begin_frontend_if_true()
val.assign(1)
ti.core.pop_scope()

ti.core.begin_frontend_if_false()
rhs = ASTTransformer.build_short_circuit_or(operands[1:])
val.assign(rhs)
ti.core.pop_scope()

return val

@staticmethod
def build_normal_bool_op(op):
def inner(operands):
result = op(operands[0].ptr, operands[1].ptr)
for i in range(2, len(operands)):
result = op(result, operands[i].ptr)
return result

return inner

@staticmethod
def build_BoolOp(ctx, node):
node.values = build_stmts(ctx, node.values)
op = {
ast.And: ti.logical_and,
ast.Or: ti.logical_or,
}.get(type(node.op))
result = op(node.values[0].ptr, node.values[1].ptr)
for i in range(2, len(node.values)):
result = op(result, node.values[i].ptr)
node.ptr = result
ops = {
ast.And: ASTTransformer.build_short_circuit_and,
ast.Or: ASTTransformer.build_short_circuit_or,
} if impl.get_runtime().short_circuit_operators else {
ast.And: ASTTransformer.build_normal_bool_op(ti.logical_and),
ast.Or: ASTTransformer.build_normal_bool_op(ti.logical_or),
}
op = ops.get(type(node.op))
node.ptr = op(node.values)
return node

@staticmethod
Expand Down
49 changes: 49 additions & 0 deletions tests/python/test_short_circuit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import taichi as ti


@ti.test(debug=True, short_circuit_operators=True)
def test_and_shorted():
a = ti.field(ti.i32, shape=10)

@ti.func
def explode() -> ti.i32:
return a[-1]
re-xyr marked this conversation as resolved.
Show resolved Hide resolved

@ti.kernel
def func() -> ti.i32:
return False and explode()

assert func() == 0


@ti.test(debug=True, short_circuit_operators=True)
def test_and_not_shorted():
@ti.kernel
def func() -> ti.i32:
return True and False

assert func() == 0


@ti.test(debug=True, short_circuit_operators=True)
def test_or_shorted():
a = ti.field(ti.i32, shape=10)

@ti.func
def explode() -> ti.i32:
return a[-1]

@ti.kernel
def func() -> ti.i32:
return True or explode()

assert func() == 1


@ti.test(debug=True, short_circuit_operators=True)
def test_or_not_shorted():
@ti.kernel
def func() -> ti.i32:
return False or True

assert func() == 1