Skip to content

Commit

Permalink
[opt] Add a DelayedIRModifier class (#1103)
Browse files Browse the repository at this point in the history
* [opt] Add a DelayedIRModifier class

* fix tests

* add assertions in the destructor
  • Loading branch information
xumingkuan authored Jun 2, 2020
1 parent 18a0eb5 commit 43e1da5
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 43 deletions.
34 changes: 34 additions & 0 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,40 @@ std::unique_ptr<Block> Block::clone() const {
return new_block;
}

DelayedIRModifier::~DelayedIRModifier() {
TI_ASSERT(to_insert_before.empty());
TI_ASSERT(to_erase.empty());
}

void DelayedIRModifier::erase(Stmt *stmt) {
to_erase.push_back(stmt);
}

void DelayedIRModifier::insert_before(Stmt *old_statement,
std::unique_ptr<Stmt> new_statements) {
to_insert_before.emplace_back(old_statement,
VecStatement(std::move(new_statements)));
}

void DelayedIRModifier::insert_before(Stmt *old_statement,
VecStatement &&new_statements) {
to_insert_before.emplace_back(old_statement, std::move(new_statements));
}

bool DelayedIRModifier::modify_ir() {
if (to_insert_before.empty() && to_erase.empty())
return false;
for (auto &i : to_insert_before) {
i.first->parent->insert_before(i.first, std::move(i.second));
}
to_insert_before.clear();
for (auto &stmt : to_erase) {
stmt->parent->erase(stmt);
}
to_erase.clear();
return true;
}

FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type,
SNode *snode,
const ExprGroup &indices,
Expand Down
13 changes: 13 additions & 0 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,19 @@ class Block : public IRNode {
TI_DEFINE_ACCEPT
};

class DelayedIRModifier {
private:
std::vector<std::pair<Stmt *, VecStatement>> to_insert_before;
std::vector<Stmt *> to_erase;

public:
~DelayedIRModifier();
void erase(Stmt *stmt);
void insert_before(Stmt *old_statement, std::unique_ptr<Stmt> new_statement);
void insert_before(Stmt *old_statement, VecStatement &&new_statements);
bool modify_ir();
};

class SNodeOpStmt : public Stmt {
public:
SNodeOpType op_type;
Expand Down
78 changes: 35 additions & 43 deletions taichi/transforms/alg_simp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@ class AlgSimp : public BasicStmtVisitor {
cast->cast_type = stmt->ret_type.data_type;
cast->ret_type.data_type = stmt->ret_type.data_type;
a = cast.get();
to_insert_before.emplace_back(std::move(cast), stmt);
modifier.insert_before(stmt, std::move(cast));
}
}

public:
static constexpr int max_weaken_exponent = 32;
using BasicStmtVisitor::visit;
bool fast_math;
std::vector<Stmt *> to_erase;
std::vector<std::pair<std::unique_ptr<Stmt>, Stmt *>> to_insert_before;
DelayedIRModifier modifier;

explicit AlgSimp(bool fast_math_)
: BasicStmtVisitor(), fast_math(fast_math_) {
Expand All @@ -44,40 +43,40 @@ class AlgSimp : public BasicStmtVisitor {
if (alg_is_zero(rhs)) {
// a +-|^ 0 -> a
stmt->replace_with(stmt->lhs);
to_erase.push_back(stmt);
modifier.erase(stmt);
} else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs)) {
// 0 +|^ a -> a
stmt->replace_with(stmt->rhs);
to_erase.push_back(stmt);
modifier.erase(stmt);
}
} else if (stmt->op_type == BinaryOpType::mul ||
stmt->op_type == BinaryOpType::div) {
if (alg_is_one(rhs)) {
// a */ 1 -> a
stmt->replace_with(stmt->lhs);
to_erase.push_back(stmt);
modifier.erase(stmt);
} else if (stmt->op_type == BinaryOpType::mul && alg_is_one(lhs)) {
// 1 * a -> a
stmt->replace_with(stmt->rhs);
to_erase.push_back(stmt);
modifier.erase(stmt);
} else if ((fast_math || is_integral(stmt->ret_type.data_type)) &&
stmt->op_type == BinaryOpType::mul &&
(alg_is_zero(lhs) || alg_is_zero(rhs))) {
// fast_math or integral operands: 0 * a -> 0, a * 0 -> 0
if (alg_is_zero(lhs) &&
lhs->ret_type.data_type == stmt->ret_type.data_type) {
stmt->replace_with(stmt->lhs);
to_erase.push_back(stmt);
modifier.erase(stmt);
} else if (alg_is_zero(rhs) &&
rhs->ret_type.data_type == stmt->ret_type.data_type) {
stmt->replace_with(stmt->rhs);
to_erase.push_back(stmt);
modifier.erase(stmt);
} else {
auto zero = Stmt::make<ConstStmt>(
LaneAttribute<TypedConstant>(stmt->ret_type.data_type));
stmt->replace_with(zero.get());
to_insert_before.emplace_back(std::move(zero), stmt);
to_erase.push_back(stmt);
modifier.insert_before(stmt, std::move(zero));
modifier.erase(stmt);
}
} else if (stmt->op_type == BinaryOpType::mul &&
(alg_is_two(lhs) || alg_is_two(rhs))) {
Expand All @@ -89,8 +88,8 @@ class AlgSimp : public BasicStmtVisitor {
auto sum = Stmt::make<BinaryOpStmt>(BinaryOpType::add, a, a);
sum->ret_type.data_type = a->ret_type.data_type;
stmt->replace_with(sum.get());
to_insert_before.emplace_back(std::move(sum), stmt);
to_erase.push_back(stmt);
modifier.insert_before(stmt, std::move(sum));
modifier.erase(stmt);
} else if (fast_math && stmt->op_type == BinaryOpType::div && rhs &&
is_real(rhs->ret_type.data_type)) {
if (alg_is_zero(rhs)) {
Expand All @@ -112,33 +111,33 @@ class AlgSimp : public BasicStmtVisitor {
reciprocal.get());
product->ret_type.data_type = stmt->ret_type.data_type;
stmt->replace_with(product.get());
to_insert_before.emplace_back(std::move(reciprocal), stmt);
to_insert_before.emplace_back(std::move(product), stmt);
to_erase.push_back(stmt);
modifier.insert_before(stmt, std::move(reciprocal));
modifier.insert_before(stmt, std::move(product));
modifier.erase(stmt);
}
}
} else if (stmt->op_type == BinaryOpType::pow) {
if (alg_is_one(rhs)) {
// a ** 1 -> a
stmt->replace_with(stmt->lhs);
to_erase.push_back(stmt);
modifier.erase(stmt);
} else if (alg_is_zero(rhs)) {
// a ** 0 -> 1
auto one = Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(1));
auto one_raw = one.get();
to_insert_before.emplace_back(std::move(one), stmt);
modifier.insert_before(stmt, std::move(one));
cast_to_result_type(one_raw, stmt);
stmt->replace_with(one_raw);
to_erase.push_back(stmt);
modifier.erase(stmt);
} else if (alg_is_two(rhs)) {
// a ** 2.0 -> a * a
auto a = stmt->lhs;
cast_to_result_type(a, stmt);
auto product = Stmt::make<BinaryOpStmt>(BinaryOpType::mul, a, a);
product->ret_type.data_type = a->ret_type.data_type;
stmt->replace_with(product.get());
to_insert_before.emplace_back(std::move(product), stmt);
to_erase.push_back(stmt);
modifier.insert_before(stmt, std::move(product));
modifier.erase(stmt);
} else if (rhs && is_integral(rhs->ret_type.data_type) &&
((is_signed(rhs->ret_type.data_type) &&
rhs->val[0].val_int() >= 0 &&
Expand All @@ -163,7 +162,7 @@ class AlgSimp : public BasicStmtVisitor {
result, a_power_of_2);
new_result->ret_type.data_type = a->ret_type.data_type;
result = new_result.get();
to_insert_before.emplace_back(std::move(new_result), stmt);
modifier.insert_before(stmt, std::move(new_result));
}
}
current_exponent <<= 1;
Expand All @@ -173,18 +172,18 @@ class AlgSimp : public BasicStmtVisitor {
BinaryOpType::mul, a_power_of_2, a_power_of_2);
new_a_power->ret_type.data_type = a->ret_type.data_type;
a_power_of_2 = new_a_power.get();
to_insert_before.emplace_back(std::move(new_a_power), stmt);
modifier.insert_before(stmt, std::move(new_a_power));
}
stmt->replace_with(result);
to_erase.push_back(stmt);
modifier.erase(stmt);
} else if (rhs && is_integral(rhs->ret_type.data_type) &&
is_signed(rhs->ret_type.data_type) &&
rhs->val[0].val_int() < 0 &&
rhs->val[0].val_int() >= -max_weaken_exponent) {
// a ** -n -> 1 / a ** n
auto one = Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(1));
auto one_raw = one.get();
to_insert_before.emplace_back(std::move(one), stmt);
modifier.insert_before(stmt, std::move(one));
cast_to_result_type(one_raw, stmt);
auto exponent =
Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(-rhs->val[0]));
Expand All @@ -194,20 +193,20 @@ class AlgSimp : public BasicStmtVisitor {
auto result =
Stmt::make<BinaryOpStmt>(BinaryOpType::div, one_raw, a_to_n.get());
stmt->replace_with(result.get());
to_insert_before.emplace_back(std::move(exponent), stmt);
to_insert_before.emplace_back(std::move(a_to_n), stmt);
to_insert_before.emplace_back(std::move(result), stmt);
to_erase.push_back(stmt);
modifier.insert_before(stmt, std::move(exponent));
modifier.insert_before(stmt, std::move(a_to_n));
modifier.insert_before(stmt, std::move(result));
modifier.erase(stmt);
}
} else if (stmt->op_type == BinaryOpType::bit_and) {
if (alg_is_minus_one(rhs)) {
// a & -1 -> a
stmt->replace_with(stmt->lhs);
to_erase.push_back(stmt);
modifier.erase(stmt);
} else if (alg_is_minus_one(lhs)) {
// -1 & a -> a
stmt->replace_with(stmt->rhs);
to_erase.push_back(stmt);
modifier.erase(stmt);
}
}
}
Expand All @@ -218,7 +217,7 @@ class AlgSimp : public BasicStmtVisitor {
return;
if (!alg_is_zero(cond)) {
// this statement has no effect
to_erase.push_back(stmt);
modifier.erase(stmt);
}
}

Expand All @@ -228,7 +227,7 @@ class AlgSimp : public BasicStmtVisitor {
return;
if (!alg_is_zero(cond)) {
// this statement has no effect
to_erase.push_back(stmt);
modifier.erase(stmt);
}
}

Expand Down Expand Up @@ -309,17 +308,10 @@ class AlgSimp : public BasicStmtVisitor {
bool modified = false;
while (true) {
node->accept(&simplifier);
if (simplifier.to_erase.empty() && simplifier.to_insert_before.empty())
if (simplifier.modifier.modify_ir())
modified = true;
else
break;
modified = true;
for (auto &i : simplifier.to_insert_before) {
i.second->insert_before_me(std::move(i.first));
}
for (auto &stmt : simplifier.to_erase) {
stmt->parent->erase(stmt);
}
simplifier.to_insert_before.clear();
simplifier.to_erase.clear();
}
return modified;
}
Expand Down

0 comments on commit 43e1da5

Please sign in to comment.