Skip to content

Commit

Permalink
[Lang] [ir] Move short-circuit boolean logic into AST-to-IR passes (#…
Browse files Browse the repository at this point in the history
…4580)

* [Lang] [ir] Move short-circuit boolean logic into AST-to-IR passes

* [Lang] Enable short circuit bool operators by default

* [ir] Remove true_mask and false_mask in IfStmt & mask_var in Block

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
re-xyr and pre-commit-ci[bot] authored Apr 18, 2022
1 parent 88181e9 commit c171215
Show file tree
Hide file tree
Showing 17 changed files with 98 additions and 113 deletions.
72 changes: 15 additions & 57 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,64 +637,23 @@ def build_UnaryOp(ctx, node):
return node.ptr

@staticmethod
def build_short_circuit_and(ast_builder, operands):
if len(operands) == 1:
return operands[0].ptr

val = impl.expr_init(None)
lhs = operands[0].ptr
impl.begin_frontend_if(ast_builder, lhs)

ast_builder.begin_frontend_if_true()
rhs = ASTTransformer.build_short_circuit_and(ast_builder, operands[1:])
val._assign(rhs)
ast_builder.pop_scope()

ast_builder.begin_frontend_if_false()
val._assign(0)
ast_builder.pop_scope()

return val

@staticmethod
def build_short_circuit_or(ast_builder, operands):
if len(operands) == 1:
return operands[0].ptr

val = impl.expr_init(None)
lhs = operands[0].ptr
impl.begin_frontend_if(ast_builder, lhs)

ast_builder.begin_frontend_if_true()
val._assign(1)
ast_builder.pop_scope()

ast_builder.begin_frontend_if_false()
rhs = ASTTransformer.build_short_circuit_or(ast_builder, operands[1:])
val._assign(rhs)
ast_builder.pop_scope()

return val

@staticmethod
def build_normal_bool_op(op):
def inner(ast_builder, operands):
result = op(operands[0].ptr, operands[1].ptr)
for i in range(2, len(operands)):
result = op(result, operands[i].ptr)
return result
def build_bool_op(op):
def inner(operands):
if len(operands) == 1:
return operands[0].ptr
return op(operands[0].ptr, inner(operands[1:]))

return inner

@staticmethod
def build_static_short_circuit_and(ast_builder, operands):
def build_static_and(operands):
for operand in operands:
if not operand.ptr:
return operand.ptr
return operands[-1].ptr

@staticmethod
def build_static_short_circuit_or(ast_builder, operands):
def build_static_or(operands):
for operand in operands:
if operand.ptr:
return operand.ptr
Expand All @@ -705,22 +664,21 @@ def build_BoolOp(ctx, node):
build_stmts(ctx, node.values)
if ctx.is_in_static_scope():
ops = {
ast.And: ASTTransformer.build_static_short_circuit_and,
ast.Or: ASTTransformer.build_static_short_circuit_or,
ast.And: ASTTransformer.build_static_and,
ast.Or: ASTTransformer.build_static_or,
}
elif impl.get_runtime().short_circuit_operators:
ops = {
ast.And: ASTTransformer.build_short_circuit_and,
ast.Or: ASTTransformer.build_short_circuit_or,
ast.And: ASTTransformer.build_bool_op(ti_ops.logical_and),
ast.Or: ASTTransformer.build_bool_op(ti_ops.logical_or),
}
else:
ops = {
ast.And:
ASTTransformer.build_normal_bool_op(ti_ops.logical_and),
ast.Or: ASTTransformer.build_normal_bool_op(ti_ops.logical_or),
ast.And: ASTTransformer.build_bool_op(ti_ops.bit_and),
ast.Or: ASTTransformer.build_bool_op(ti_ops.bit_or),
}
op = ops.get(type(node.op))
node.ptr = op(ctx.ast_builder, node.values)
node.ptr = op(node.values)
return node.ptr

@staticmethod
Expand Down Expand Up @@ -765,7 +723,7 @@ def build_Compare(ctx, node):
raise TaichiSyntaxError(
f'"{type(node_op).__name__}" is not supported in Taichi kernels.'
)
val = ti_ops.logical_and(val, op(l, r))
val = ti_ops.bit_and(val, op(l, r))
node.ptr = val
return node.ptr

Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ class _SpecialConfig:
def __init__(self):
self.log_level = 'info'
self.gdb_trigger = False
self.short_circuit_operators = False
self.short_circuit_operators = True


def prepare_sandbox():
Expand Down
33 changes: 30 additions & 3 deletions python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,9 +1083,36 @@ def bit_shr(x1, x2):
return _binary_operation(_ti_core.expr_bit_shr, _bt_ops_mod.rshift, x1, x2)


# We don't have logic_and/or instructions yet:
logical_or = bit_or
logical_and = bit_and
@binary
def logical_and(a, b):
"""Compute logical_and
Args:
a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): value LHS
b (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): value RHS
Returns:
Union[:class:`~taichi.lang.expr.Expr`, bool]: LHS logical-and RHS (with short-circuit semantics)
"""
return _binary_operation(_ti_core.expr_logical_and, lambda a, b: a and b,
a, b)


@binary
def logical_or(a, b):
"""Compute logical_or
Args:
a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): value LHS
b (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): value RHS
Returns:
Union[:class:`~taichi.lang.expr.Expr`, bool]: LHS logical-or RHS (with short-circuit semantics)
"""
return _binary_operation(_ti_core.expr_logical_or, lambda a, b: a or b, a,
b)


@ternary
Expand Down
8 changes: 1 addition & 7 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1140,12 +1140,7 @@ void CodeGenLLVM::visit(LocalLoadStmt *stmt) {
}

void CodeGenLLVM::visit(LocalStoreStmt *stmt) {
auto mask = stmt->parent->mask();
if (mask && stmt->width() != 1) {
TI_NOT_IMPLEMENTED
} else {
builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]);
}
builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]);
}

void CodeGenLLVM::visit(AssertStmt *stmt) {
Expand Down Expand Up @@ -1362,7 +1357,6 @@ void CodeGenLLVM::visit(GlobalPtrStmt *stmt) {
}

void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
TI_ASSERT(!stmt->parent->mask() || stmt->width() == 1);
TI_ASSERT(llvm_val[stmt->val]);
TI_ASSERT(llvm_val[stmt->dest]);
auto ptr_type = stmt->dest->ret_type->as<PointerType>();
Expand Down
2 changes: 2 additions & 0 deletions taichi/inc/binary_op.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ PER_BINARY_OP(cmp_ne)
PER_BINARY_OP(atan2)
PER_BINARY_OP(pow)
PER_BINARY_OP(undefined)
PER_BINARY_OP(logical_or)
PER_BINARY_OP(logical_and)
8 changes: 4 additions & 4 deletions taichi/ir/expression_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ DEFINE_EXPRESSION_OP_BINARY(-, sub)
DEFINE_EXPRESSION_OP_BINARY(*, mul)
DEFINE_EXPRESSION_OP_BINARY(/, div)
DEFINE_EXPRESSION_OP_BINARY(%, mod)
DEFINE_EXPRESSION_OP_BINARY(&&, bit_and)
DEFINE_EXPRESSION_OP_BINARY(||, bit_or)
// DEFINE_EXPRESSION_OP_BINARY(&, bit_and)
// DEFINE_EXPRESSION_OP_BINARY(|, bit_or)
DEFINE_EXPRESSION_OP_BINARY(&&, logical_and)
DEFINE_EXPRESSION_OP_BINARY(||, logical_or)
DEFINE_EXPRESSION_OP_BINARY(&, bit_and)
DEFINE_EXPRESSION_OP_BINARY(|, bit_or)
DEFINE_EXPRESSION_OP_BINARY(^, bit_xor)
DEFINE_EXPRESSION_OP_BINARY(<<, bit_shl)
DEFINE_EXPRESSION_OP_BINARY(>>, bit_sar)
Expand Down
33 changes: 32 additions & 1 deletion taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,10 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
if (binary_is_bitwise(type) &&
(!is_integral(lhs_type) || !is_integral(rhs_type)))
error();
if (is_comparison(type)) {
if (binary_is_logical(type) &&
(lhs_type != PrimitiveType::i32 || rhs_type != PrimitiveType::i32))
error();
if (is_comparison(type) || binary_is_logical(type)) {
ret_type = PrimitiveType::i32;
return;
}
Expand All @@ -210,6 +213,34 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) {
// if (stmt)
// return;
flatten_rvalue(lhs, ctx);
if (binary_is_logical(type)) {
auto result = ctx->push_back<AllocaStmt>(ret_type);
ctx->push_back<LocalStoreStmt>(result, lhs->stmt);
auto cond = ctx->push_back<LocalLoadStmt>(LocalAddress(result, 0));
auto if_stmt = ctx->push_back<IfStmt>(cond);

FlattenContext rctx;
rctx.current_block = ctx->current_block;
flatten_rvalue(rhs, &rctx);
rctx.push_back<LocalStoreStmt>(result, rhs->stmt);

auto true_block = std::make_unique<Block>();
if (type == BinaryOpType::logical_and) {
true_block->set_statements(std::move(rctx.stmts));
}
if_stmt->set_true_statements(std::move(true_block));

auto false_block = std::make_unique<Block>();
if (type == BinaryOpType::logical_or) {
false_block->set_statements(std::move(rctx.stmts));
}
if_stmt->set_false_statements(std::move(false_block));

auto ret = ctx->push_back<LocalLoadStmt>(LocalAddress(result, 0));
ret->tb = tb;
stmt = ret;
return;
}
flatten_rvalue(rhs, ctx);
ctx->push_back(std::make_unique<BinaryOpStmt>(type, lhs->stmt, rhs->stmt));
ctx->stmts.back()->tb = tb;
Expand Down
11 changes: 0 additions & 11 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,16 +326,6 @@ Stmt *Block::lookup_var(const Identifier &ident) const {
}
}

Stmt *Block::mask() {
if (mask_var)
return mask_var;
else if (parent_block() == nullptr) {
return nullptr;
} else {
return parent_block()->mask();
}
}

void Block::set_statements(VecStatement &&stmts) {
statements.clear();
for (int i = 0; i < (int)stmts.size(); i++) {
Expand Down Expand Up @@ -410,7 +400,6 @@ stmt_vector::iterator Block::find(Stmt *stmt) {
std::unique_ptr<Block> Block::clone() const {
auto new_block = std::make_unique<Block>();
new_block->parent_stmt = parent_stmt;
new_block->mask_var = mask_var;
new_block->stop_gradients = stop_gradients;
new_block->statements.reserve(size());
for (auto &stmt : statements)
Expand Down
3 changes: 0 additions & 3 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -603,15 +603,13 @@ class Block : public IRNode {
Stmt *parent_stmt{nullptr};
stmt_vector statements;
stmt_vector trash_bin;
Stmt *mask_var{nullptr};
std::vector<SNode *> stop_gradients;

// Only used in frontend. Stores LoopIndexStmt or BinaryOpStmt for loop
// variables, and AllocaStmt for other variables.
std::map<Identifier, Stmt *> local_var_to_stmt;

Block() {
mask_var = nullptr;
parent_stmt = nullptr;
kernel = nullptr;
}
Expand Down Expand Up @@ -648,7 +646,6 @@ class Block : public IRNode {
VecStatement &&new_statements,
bool replace_usages = true);
Stmt *lookup_var(const Identifier &ident) const;
Stmt *mask();
IRNode *get_parent() const override;

Stmt *back() const {
Expand Down
5 changes: 1 addition & 4 deletions taichi/ir/statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,7 @@ bool LocalLoadStmt::has_source(Stmt *alloca) const {
return false;
}

IfStmt::IfStmt(Stmt *cond)
: cond(cond), true_mask(nullptr), false_mask(nullptr) {
IfStmt::IfStmt(Stmt *cond) : cond(cond) {
TI_STMT_REG_FIELDS;
}

Expand All @@ -235,8 +234,6 @@ void IfStmt::set_false_statements(

std::unique_ptr<Stmt> IfStmt::clone() const {
auto new_stmt = std::make_unique<IfStmt>(cond);
new_stmt->true_mask = true_mask;
new_stmt->false_mask = false_mask;
if (true_statements)
new_stmt->set_true_statements(true_statements->clone());
if (false_statements)
Expand Down
3 changes: 1 addition & 2 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,6 @@ class LocalStoreStmt : public Stmt {
class IfStmt : public Stmt {
public:
Stmt *cond;
Stmt *true_mask, *false_mask;
std::unique_ptr<Block> true_statements, false_statements;

explicit IfStmt(Stmt *cond);
Expand All @@ -672,7 +671,7 @@ class IfStmt : public Stmt {

std::unique_ptr<Stmt> clone() const override;

TI_STMT_DEF_FIELDS(cond, true_mask, false_mask);
TI_STMT_DEF_FIELDS(cond);
TI_DEFINE_ACCEPT
};

Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/stmt_op_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ std::string binary_op_type_symbol(BinaryOpType type) {
REGISTER_TYPE(cmp_eq, ==);
REGISTER_TYPE(bit_and, &);
REGISTER_TYPE(bit_or, |);
REGISTER_TYPE(logical_and, &&);
REGISTER_TYPE(logical_or, ||);
REGISTER_TYPE(bit_xor, ^);
REGISTER_TYPE(pow, pow);
REGISTER_TYPE(bit_shl, <<);
Expand Down
4 changes: 4 additions & 0 deletions taichi/ir/stmt_op_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ inline bool binary_is_bitwise(BinaryOpType t) {
t == BinaryOpType ::bit_sar;
}

inline bool binary_is_logical(BinaryOpType t) {
return t == BinaryOpType ::logical_and || t == BinaryOpType ::logical_or;
}

std::string binary_op_type_name(BinaryOpType type);

inline bool is_comparison(BinaryOpType type) {
Expand Down
4 changes: 2 additions & 2 deletions taichi/math/svd.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ TLANG_NAMESPACE_BEGIN

template <typename Tf, typename Ti>
Expr svd_bitwise_or(const Expr &a, const Expr &b) {
return bit_cast<Tf>(bit_cast<Ti>(a) || bit_cast<Ti>(b));
return bit_cast<Tf>(bit_cast<Ti>(a) | bit_cast<Ti>(b));
}

template <typename Tf, typename Ti>
Expand All @@ -17,7 +17,7 @@ Expr svd_bitwise_xor(const Expr &a, const Expr &b) {

template <typename Tf, typename Ti>
Expr svd_bitwise_and(const Expr &a, const Expr &b) {
return bit_cast<Tf>(bit_cast<Ti>(a) && bit_cast<Ti>(b));
return bit_cast<Tf>(bit_cast<Ti>(a) & bit_cast<Ti>(b));
}

template <typename Tf, typename Ti>
Expand Down
3 changes: 3 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,10 @@ void export_lang(py::module &m) {
DEFINE_EXPRESSION_OP(bit_shr)
DEFINE_EXPRESSION_OP(bit_sar)
DEFINE_EXPRESSION_OP(bit_not)

DEFINE_EXPRESSION_OP(logic_not)
DEFINE_EXPRESSION_OP(logical_and)
DEFINE_EXPRESSION_OP(logical_or)

DEFINE_EXPRESSION_OP(add)
DEFINE_EXPRESSION_OP(sub)
Expand Down
Loading

0 comments on commit c171215

Please sign in to comment.