Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Opt] Improve the binary ops simplify pass #1646

Merged
merged 3 commits into from
Aug 6, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions taichi/transforms/binary_op_simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,25 @@ class BinaryOpSimp : public BasicStmtVisitor {
using BasicStmtVisitor::visit;
bool fast_math;
DelayedIRModifier modifier;
bool operand_swapped;

explicit BinaryOpSimp(bool fast_math_)
: BasicStmtVisitor(), fast_math(fast_math_) {
: BasicStmtVisitor(), fast_math(fast_math_), operand_swapped(false) {
}

void visit(BinaryOpStmt *stmt) override {
// swap lhs and rhs if lhs is a const and op is commutative
// Swap lhs and rhs if lhs is a const and op is commutative.
auto const_lhs = stmt->lhs->cast<ConstStmt>();
if (const_lhs && is_commutative(stmt->op_type) &&
!stmt->rhs->is<ConstStmt>()) {
auto rhs_stmt = stmt->rhs;
stmt->lhs = rhs_stmt;
stmt->rhs = const_lhs;
operand_swapped = true;
}
if (!fast_math) {
// Disable other optimizations if fast_math=True and the data type is not
// integral.
if (!fast_math && !is_integral(stmt->ret_type.data_type)) {
return;
}
auto binary_lhs = stmt->lhs->cast<BinaryOpStmt>();
Expand All @@ -36,16 +40,25 @@ class BinaryOpSimp : public BasicStmtVisitor {
if (!const_lhs_rhs || binary_lhs->lhs->is<ConstStmt>()) {
return;
}
auto op1 = binary_lhs->op_type;
auto op2 = stmt->op_type;
// Disables (a / b) * c -> a / (b / c), (a * b) / c -> a * (b / c)
// when the data type is integral.
if (is_integral(stmt->ret_type.data_type) &&
((op1 == BinaryOpType::div && op2 == BinaryOpType::mul) ||
(op1 == BinaryOpType::mul && op2 == BinaryOpType::div))) {
return;
}
BinaryOpType new_op2;
// original:
// stmt = (a op1 b) op2 c
// rearrange to:
// stmt = a op1 (b op2 c)
if (can_rearrange_associative(binary_lhs->op_type, stmt->op_type)) {
auto bin_op =
Stmt::make<BinaryOpStmt>(stmt->op_type, const_lhs_rhs, const_rhs);
if (can_rearrange_associative(op1, op2, new_op2)) {
auto bin_op = Stmt::make<BinaryOpStmt>(new_op2, const_lhs_rhs, const_rhs);
bin_op->ret_type.data_type = stmt->ret_type.data_type;
auto new_stmt = Stmt::make<BinaryOpStmt>(binary_lhs->op_type,
binary_lhs->lhs, bin_op.get());
auto new_stmt =
Stmt::make<BinaryOpStmt>(op1, binary_lhs->lhs, bin_op.get());
new_stmt->ret_type.data_type = stmt->ret_type.data_type;

modifier.insert_before(stmt, std::move(bin_op));
Expand All @@ -55,19 +68,32 @@ class BinaryOpSimp : public BasicStmtVisitor {
}
}

static bool can_rearrange_associative(BinaryOpType op1, BinaryOpType op2) {
if (op1 == BinaryOpType::add &&
static bool can_rearrange_associative(BinaryOpType op1,
BinaryOpType op2,
BinaryOpType &new_op2) {
if ((op1 == BinaryOpType::add || op1 == BinaryOpType::sub) &&
(op2 == BinaryOpType::add || op2 == BinaryOpType::sub)) {
if (op1 == BinaryOpType::add)
new_op2 = op2;
else
new_op2 =
(op2 == BinaryOpType::add ? BinaryOpType::sub : BinaryOpType::add);
return true;
}
if (op1 == BinaryOpType::mul &&
if ((op1 == BinaryOpType::mul || op1 == BinaryOpType::div) &&
(op2 == BinaryOpType::mul || op2 == BinaryOpType::div)) {
if (op1 == BinaryOpType::mul)
new_op2 = op2;
else
new_op2 =
(op2 == BinaryOpType::mul ? BinaryOpType::div : BinaryOpType::mul);
return true;
}
// for bit operations it only holds when two ops are the same
if ((op1 == BinaryOpType::bit_and || op1 == BinaryOpType::bit_or ||
op1 == BinaryOpType::bit_xor) &&
op1 == op2) {
new_op2 = op2;
return true;
}
return false;
Expand All @@ -89,7 +115,7 @@ class BinaryOpSimp : public BasicStmtVisitor {
} else
break;
}
return modified;
return modified || simplifier.operand_swapped;
}
};

Expand Down