Skip to content

Commit

Permalink
[Lang] [ir] Move short-circuit boolean logic into AST-to-IR passes
Browse files Browse the repository at this point in the history
  • Loading branch information
re-xyr committed Mar 19, 2022
1 parent 463cbea commit eafa0b1
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 64 deletions.
66 changes: 12 additions & 54 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,64 +621,23 @@ def build_UnaryOp(ctx, node):
return node.ptr

@staticmethod
def build_short_circuit_and(ast_builder, operands):
if len(operands) == 1:
return operands[0].ptr

val = impl.expr_init(None)
lhs = operands[0].ptr
impl.begin_frontend_if(ast_builder, lhs)

ast_builder.begin_frontend_if_true()
rhs = ASTTransformer.build_short_circuit_and(ast_builder, operands[1:])
val._assign(rhs)
ast_builder.pop_scope()

ast_builder.begin_frontend_if_false()
val._assign(0)
ast_builder.pop_scope()

return val

@staticmethod
def build_short_circuit_or(ast_builder, operands):
if len(operands) == 1:
return operands[0].ptr

val = impl.expr_init(None)
lhs = operands[0].ptr
impl.begin_frontend_if(ast_builder, lhs)

ast_builder.begin_frontend_if_true()
val._assign(1)
ast_builder.pop_scope()

ast_builder.begin_frontend_if_false()
rhs = ASTTransformer.build_short_circuit_or(ast_builder, operands[1:])
val._assign(rhs)
ast_builder.pop_scope()

return val

@staticmethod
def build_normal_bool_op(op):
def build_bool_op(op):
def inner(ast_builder, operands):
result = op(operands[0].ptr, operands[1].ptr)
for i in range(2, len(operands)):
result = op(result, operands[i].ptr)
return result
if len(operands) == 1:
return operands[0].ptr
return op(operands[0].ptr, inner(ast_builder, operands[1:]))

return inner

@staticmethod
def build_static_short_circuit_and(ast_builder, operands):
def build_static_and(ast_builder, operands):
for operand in operands:
if not operand.ptr:
return operand.ptr
return operands[-1].ptr

@staticmethod
def build_static_short_circuit_or(ast_builder, operands):
def build_static_or(ast_builder, operands):
for operand in operands:
if operand.ptr:
return operand.ptr
Expand All @@ -689,19 +648,18 @@ def build_BoolOp(ctx, node):
build_stmts(ctx, node.values)
if ctx.is_in_static_scope():
ops = {
ast.And: ASTTransformer.build_static_short_circuit_and,
ast.Or: ASTTransformer.build_static_short_circuit_or,
ast.And: ASTTransformer.build_static_and,
ast.Or: ASTTransformer.build_static_or,
}
elif impl.get_runtime().short_circuit_operators:
ops = {
ast.And: ASTTransformer.build_short_circuit_and,
ast.Or: ASTTransformer.build_short_circuit_or,
ast.And: ASTTransformer.build_bool_op(ti_ops.logical_and),
ast.Or: ASTTransformer.build_bool_op(ti_ops.logical_or),
}
else:
ops = {
ast.And:
ASTTransformer.build_normal_bool_op(ti_ops.logical_and),
ast.Or: ASTTransformer.build_normal_bool_op(ti_ops.logical_or),
ast.And: ASTTransformer.build_bool_op(ti_ops.bit_and),
ast.Or: ASTTransformer.build_bool_op(ti_ops.bit_or),
}
op = ops.get(type(node.op))
node.ptr = op(ctx.ast_builder, node.values)
Expand Down
33 changes: 30 additions & 3 deletions python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,9 +1083,36 @@ def bit_shr(x1, x2):
return _binary_operation(_ti_core.expr_bit_shr, _bt_ops_mod.rshift, x1, x2)


# We don't have logic_and/or instructions yet:
logical_or = bit_or
logical_and = bit_and
@binary
def logical_and(a, b):
"""Compute logical_and
Args:
a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): value LHS
b (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): value RHS
Returns:
Union[:class:`~taichi.lang.expr.Expr`, bool]: LHS logical-and RHS (with short-circuit semantics)
"""
return _binary_operation(_ti_core.expr_logical_and, lambda a, b: a and b,
a, b)


@binary
def logical_or(a, b):
"""Compute logical_or
Args:
a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): value LHS
b (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): value RHS
Returns:
Union[:class:`~taichi.lang.expr.Expr`, bool]: LHS logical-or RHS (with short-circuit semantics)
"""
return _binary_operation(_ti_core.expr_logical_or, lambda a, b: a or b, a,
b)


@ternary
Expand Down
2 changes: 2 additions & 0 deletions taichi/inc/binary_op.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ PER_BINARY_OP(cmp_ne)
PER_BINARY_OP(atan2)
PER_BINARY_OP(pow)
PER_BINARY_OP(undefined)
PER_BINARY_OP(logical_or)
PER_BINARY_OP(logical_and)
8 changes: 4 additions & 4 deletions taichi/ir/expression_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ DEFINE_EXPRESSION_OP_BINARY(-, sub)
DEFINE_EXPRESSION_OP_BINARY(*, mul)
DEFINE_EXPRESSION_OP_BINARY(/, div)
DEFINE_EXPRESSION_OP_BINARY(%, mod)
DEFINE_EXPRESSION_OP_BINARY(&&, bit_and)
DEFINE_EXPRESSION_OP_BINARY(||, bit_or)
// DEFINE_EXPRESSION_OP_BINARY(&, bit_and)
// DEFINE_EXPRESSION_OP_BINARY(|, bit_or)
DEFINE_EXPRESSION_OP_BINARY(&&, logical_and)
DEFINE_EXPRESSION_OP_BINARY(||, logical_or)
DEFINE_EXPRESSION_OP_BINARY(&, bit_and)
DEFINE_EXPRESSION_OP_BINARY(|, bit_or)
DEFINE_EXPRESSION_OP_BINARY(^, bit_xor)
DEFINE_EXPRESSION_OP_BINARY(<<, bit_shl)
DEFINE_EXPRESSION_OP_BINARY(>>, bit_sar)
Expand Down
33 changes: 32 additions & 1 deletion taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,10 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
if (binary_is_bitwise(type) &&
(!is_integral(lhs_type) || !is_integral(rhs_type)))
error();
if (is_comparison(type)) {
if (binary_is_logical(type) &&
(lhs_type != PrimitiveType::i32 || rhs_type != PrimitiveType::i32))
error();
if (is_comparison(type) || binary_is_logical(type)) {
ret_type = PrimitiveType::i32;
return;
}
Expand All @@ -216,6 +219,34 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) {
// if (stmt)
// return;
flatten_rvalue(lhs, ctx);
if (binary_is_logical(type)) {
auto result = ctx->push_back<AllocaStmt>(ret_type);
ctx->push_back<LocalStoreStmt>(result, lhs->stmt);
auto cond = ctx->push_back<LocalLoadStmt>(LocalAddress(result, 0));
auto if_stmt = ctx->push_back<IfStmt>(cond);

FlattenContext rctx;
rctx.current_block = ctx->current_block;
flatten_rvalue(rhs, &rctx);
rctx.push_back<LocalStoreStmt>(result, rhs->stmt);

auto true_block = std::make_unique<Block>();
if (type == BinaryOpType::logical_and) {
true_block->set_statements(std::move(rctx.stmts));
}
if_stmt->set_true_statements(std::move(true_block));

auto false_block = std::make_unique<Block>();
if (type == BinaryOpType::logical_or) {
false_block->set_statements(std::move(rctx.stmts));
}
if_stmt->set_false_statements(std::move(false_block));

auto ret = ctx->push_back<LocalLoadStmt>(LocalAddress(result, 0));
ret->tb = tb;
stmt = ret;
return;
}
flatten_rvalue(rhs, ctx);
ctx->push_back(std::make_unique<BinaryOpStmt>(type, lhs->stmt, rhs->stmt));
ctx->stmts.back()->tb = tb;
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/stmt_op_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ std::string binary_op_type_symbol(BinaryOpType type) {
REGISTER_TYPE(cmp_eq, ==);
REGISTER_TYPE(bit_and, &);
REGISTER_TYPE(bit_or, |);
REGISTER_TYPE(logical_and, &&);
REGISTER_TYPE(logical_or, ||);
REGISTER_TYPE(bit_xor, ^);
REGISTER_TYPE(pow, pow);
REGISTER_TYPE(bit_shl, <<);
Expand Down
4 changes: 4 additions & 0 deletions taichi/ir/stmt_op_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ inline bool binary_is_bitwise(BinaryOpType t) {
t == BinaryOpType ::bit_sar;
}

inline bool binary_is_logical(BinaryOpType t) {
return t == BinaryOpType ::logical_and || t == BinaryOpType ::logical_or;
}

std::string binary_op_type_name(BinaryOpType type);

inline bool is_comparison(BinaryOpType type) {
Expand Down
4 changes: 2 additions & 2 deletions taichi/math/svd.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ TLANG_NAMESPACE_BEGIN

template <typename Tf, typename Ti>
Expr svd_bitwise_or(const Expr &a, const Expr &b) {
return bit_cast<Tf>(bit_cast<Ti>(a) || bit_cast<Ti>(b));
return bit_cast<Tf>(bit_cast<Ti>(a) | bit_cast<Ti>(b));
}

template <typename Tf, typename Ti>
Expand All @@ -17,7 +17,7 @@ Expr svd_bitwise_xor(const Expr &a, const Expr &b) {

template <typename Tf, typename Ti>
Expr svd_bitwise_and(const Expr &a, const Expr &b) {
return bit_cast<Tf>(bit_cast<Ti>(a) && bit_cast<Ti>(b));
return bit_cast<Tf>(bit_cast<Ti>(a) & bit_cast<Ti>(b));
}

template <typename Tf, typename Ti>
Expand Down
3 changes: 3 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,10 @@ void export_lang(py::module &m) {
DEFINE_EXPRESSION_OP(bit_shr)
DEFINE_EXPRESSION_OP(bit_sar)
DEFINE_EXPRESSION_OP(bit_not)

DEFINE_EXPRESSION_OP(logic_not)
DEFINE_EXPRESSION_OP(logical_and)
DEFINE_EXPRESSION_OP(logical_or)

DEFINE_EXPRESSION_OP(add)
DEFINE_EXPRESSION_OP(sub)
Expand Down

0 comments on commit eafa0b1

Please sign in to comment.