From 3ffcccf56a07bc2cac2a2c4829426f3c43f1b602 Mon Sep 17 00:00:00 2001 From: Xy Ren Date: Sat, 27 Nov 2021 12:55:22 +0800 Subject: [PATCH] [Lang] add option `short_circuit_operators` for shor-circuiting boolean ops (#3572) --- python/taichi/lang/__init__.py | 4 ++ python/taichi/lang/ast/ast_transformer.py | 37 ++++++++++++++++- tests/python/test_short_circuit.py | 49 +++++++++++++++++++++++ 3 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 tests/python/test_short_circuit.py diff --git a/python/taichi/lang/__init__.py b/python/taichi/lang/__init__.py index d0749a4d9722c7..2118148f00ac79 100644 --- a/python/taichi/lang/__init__.py +++ b/python/taichi/lang/__init__.py @@ -421,6 +421,7 @@ def __init__(self): self.gdb_trigger = False self.excepthook = False self.experimental_real_function = False + self.short_circuit_operator = False def prepare_sandbox(): @@ -591,6 +592,7 @@ def init(arch=None, env_spec.add('gdb_trigger') env_spec.add('excepthook') env_spec.add('experimental_real_function') + env_spec.add('short_circuit_operators') # compiler configurations (ti.cfg): for key in dir(ti.cfg): @@ -620,6 +622,8 @@ def init(arch=None, impl.get_runtime().print_preprocessed = spec_cfg.print_preprocessed impl.get_runtime().experimental_real_function = \ spec_cfg.experimental_real_function + impl.get_runtime().short_circuit_operators= \ + spec_cfg.short_circuit_operators ti.set_logging_level(spec_cfg.log_level.lower()) if spec_cfg.excepthook: # TODO(#1405): add a way to restore old excepthook diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 3baf7ce6bdc77e..a9b4f57b414bd9 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -476,13 +476,46 @@ def build_UnaryOp(ctx, node): node.ptr = op(node.operand.ptr) return node + + @staticmethod + def build_short_circuit_and(lhs, rhs): + val = ti.expr_init(None) + + ti.begin_frontend_if(lhs) + ti.core.begin_frontend_if_true() + val.assign(rhs) + ti.core.pop_scope() + ti.core.begin_frontend_if_false() + val.assign(0) + ti.core.pop_scope() + + return val + + @staticmethod + def build_short_circuit_or(lhs, rhs): + val = ti.expr_init(None) + + ti.begin_frontend_if(lhs) + ti.core.begin_frontend_if_true() + val.assign(1) + ti.core.pop_scope() + ti.core.begin_frontend_if_false() + val.assign(rhs) + ti.core.pop_scope() + + return val + @staticmethod def build_BoolOp(ctx, node): node.values = build_stmts(ctx, node.values) - op = { + ops = { + ast.And: ASTTransformer.build_short_circuit_and, + ast.Or: ASTTransformer.build_short_circuit_or, + } if impl.get_runtime().short_circuit_operators else { ast.And: ti.logical_and, ast.Or: ti.logical_or, - }.get(type(node.op)) + } + op = ops.get(type(node.op)) result = op(node.values[0].ptr, node.values[1].ptr) for i in range(2, len(node.values)): result = op(result, node.values[i].ptr) diff --git a/tests/python/test_short_circuit.py b/tests/python/test_short_circuit.py new file mode 100644 index 00000000000000..8d9604af1a9d18 --- /dev/null +++ b/tests/python/test_short_circuit.py @@ -0,0 +1,49 @@ +import taichi as ti + + +@ti.test(short_circuit_operators=True) +def test_and_shorted(): + a = ti.field(ti.i32, shape=10) + + @ti.func + def explode() -> ti.i32: + return a[-1] + + @ti.kernel + def func() -> ti.i32: + return False and explode() + + assert func() == 0 + + +@ti.test(short_circuit_operators=True) +def test_and_not_shorted(): + @ti.kernel + def func() -> ti.i32: + return True and False + + assert func() == 0 + + +@ti.test(short_circuit_operators=True) +def test_or_shorted(): + a = ti.field(ti.i32, shape=10) + + @ti.func + def explode() -> ti.i32: + return a[-1] + + @ti.kernel + def func() -> ti.i32: + return True or explode() + + assert func() == 1 + + +@ti.test(short_circuit_operators=True) +def test_or_not_shorted(): + @ti.kernel + def func() -> ti.i32: + return False or True + + assert func() == 1