diff --git a/taichi/transforms/binary_op_simplify.cpp b/taichi/transforms/binary_op_simplify.cpp index 933e844d8807a..a07a47a97d55b 100644 --- a/taichi/transforms/binary_op_simplify.cpp +++ b/taichi/transforms/binary_op_simplify.cpp @@ -10,21 +10,25 @@ class BinaryOpSimp : public BasicStmtVisitor { using BasicStmtVisitor::visit; bool fast_math; DelayedIRModifier modifier; + bool operand_swapped; explicit BinaryOpSimp(bool fast_math_) - : BasicStmtVisitor(), fast_math(fast_math_) { + : BasicStmtVisitor(), fast_math(fast_math_), operand_swapped(false) { } void visit(BinaryOpStmt *stmt) override { - // swap lhs and rhs if lhs is a const and op is commutative + // Swap lhs and rhs if lhs is a const and op is commutative. auto const_lhs = stmt->lhs->cast(); if (const_lhs && is_commutative(stmt->op_type) && !stmt->rhs->is()) { auto rhs_stmt = stmt->rhs; stmt->lhs = rhs_stmt; stmt->rhs = const_lhs; + operand_swapped = true; } - if (!fast_math) { + // Disable other optimizations if fast_math=True and the data type is not + // integral. + if (!fast_math && !is_integral(stmt->ret_type.data_type)) { return; } auto binary_lhs = stmt->lhs->cast(); @@ -36,16 +40,25 @@ class BinaryOpSimp : public BasicStmtVisitor { if (!const_lhs_rhs || binary_lhs->lhs->is()) { return; } + auto op1 = binary_lhs->op_type; + auto op2 = stmt->op_type; + // Disables (a / b) * c -> a / (b / c), (a * b) / c -> a * (b / c) + // when the data type is integral. + if (is_integral(stmt->ret_type.data_type) && + ((op1 == BinaryOpType::div && op2 == BinaryOpType::mul) || + (op1 == BinaryOpType::mul && op2 == BinaryOpType::div))) { + return; + } + BinaryOpType new_op2; // original: // stmt = (a op1 b) op2 c // rearrange to: // stmt = a op1 (b op2 c) - if (can_rearrange_associative(binary_lhs->op_type, stmt->op_type)) { - auto bin_op = - Stmt::make(stmt->op_type, const_lhs_rhs, const_rhs); + if (can_rearrange_associative(op1, op2, new_op2)) { + auto bin_op = Stmt::make(new_op2, const_lhs_rhs, const_rhs); bin_op->ret_type.data_type = stmt->ret_type.data_type; - auto new_stmt = Stmt::make(binary_lhs->op_type, - binary_lhs->lhs, bin_op.get()); + auto new_stmt = + Stmt::make(op1, binary_lhs->lhs, bin_op.get()); new_stmt->ret_type.data_type = stmt->ret_type.data_type; modifier.insert_before(stmt, std::move(bin_op)); @@ -55,19 +68,32 @@ class BinaryOpSimp : public BasicStmtVisitor { } } - static bool can_rearrange_associative(BinaryOpType op1, BinaryOpType op2) { - if (op1 == BinaryOpType::add && + static bool can_rearrange_associative(BinaryOpType op1, + BinaryOpType op2, + BinaryOpType &new_op2) { + if ((op1 == BinaryOpType::add || op1 == BinaryOpType::sub) && (op2 == BinaryOpType::add || op2 == BinaryOpType::sub)) { + if (op1 == BinaryOpType::add) + new_op2 = op2; + else + new_op2 = + (op2 == BinaryOpType::add ? BinaryOpType::sub : BinaryOpType::add); return true; } - if (op1 == BinaryOpType::mul && + if ((op1 == BinaryOpType::mul || op1 == BinaryOpType::div) && (op2 == BinaryOpType::mul || op2 == BinaryOpType::div)) { + if (op1 == BinaryOpType::mul) + new_op2 = op2; + else + new_op2 = + (op2 == BinaryOpType::mul ? BinaryOpType::div : BinaryOpType::mul); return true; } // for bit operations it only holds when two ops are the same if ((op1 == BinaryOpType::bit_and || op1 == BinaryOpType::bit_or || op1 == BinaryOpType::bit_xor) && op1 == op2) { + new_op2 = op2; return true; } return false; @@ -89,7 +115,7 @@ class BinaryOpSimp : public BasicStmtVisitor { } else break; } - return modified; + return modified || simplifier.operand_swapped; } };