From 43e1da5d511f6ac949504c80fa242b14fde1cf12 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 1 Jun 2020 20:49:34 -0400 Subject: [PATCH] [opt] Add a DelayedIRModifier class (#1103) * [opt] Add a DelayedIRModifier class * fix tests * add assertions in the destructor --- taichi/ir/ir.cpp | 34 +++++++++++++++ taichi/ir/ir.h | 13 ++++++ taichi/transforms/alg_simp.cpp | 78 +++++++++++++++------------------- 3 files changed, 82 insertions(+), 43 deletions(-) diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index d30674c339a96..7169ec106fbd1 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -800,6 +800,40 @@ std::unique_ptr Block::clone() const { return new_block; } +DelayedIRModifier::~DelayedIRModifier() { + TI_ASSERT(to_insert_before.empty()); + TI_ASSERT(to_erase.empty()); +} + +void DelayedIRModifier::erase(Stmt *stmt) { + to_erase.push_back(stmt); +} + +void DelayedIRModifier::insert_before(Stmt *old_statement, + std::unique_ptr new_statements) { + to_insert_before.emplace_back(old_statement, + VecStatement(std::move(new_statements))); +} + +void DelayedIRModifier::insert_before(Stmt *old_statement, + VecStatement &&new_statements) { + to_insert_before.emplace_back(old_statement, std::move(new_statements)); +} + +bool DelayedIRModifier::modify_ir() { + if (to_insert_before.empty() && to_erase.empty()) + return false; + for (auto &i : to_insert_before) { + i.first->parent->insert_before(i.first, std::move(i.second)); + } + to_insert_before.clear(); + for (auto &stmt : to_erase) { + stmt->parent->erase(stmt); + } + to_erase.clear(); + return true; +} + FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type, SNode *snode, const ExprGroup &indices, diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 0c052302a117a..018963713dfce 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -872,6 +872,19 @@ class Block : public IRNode { TI_DEFINE_ACCEPT }; +class DelayedIRModifier { + private: + std::vector> to_insert_before; + std::vector to_erase; + + public: + ~DelayedIRModifier(); + void erase(Stmt *stmt); + void insert_before(Stmt *old_statement, std::unique_ptr new_statement); + void insert_before(Stmt *old_statement, VecStatement &&new_statements); + bool modify_ir(); +}; + class SNodeOpStmt : public Stmt { public: SNodeOpType op_type; diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 35bf20a16c2bb..730da0b41bad8 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -14,7 +14,7 @@ class AlgSimp : public BasicStmtVisitor { cast->cast_type = stmt->ret_type.data_type; cast->ret_type.data_type = stmt->ret_type.data_type; a = cast.get(); - to_insert_before.emplace_back(std::move(cast), stmt); + modifier.insert_before(stmt, std::move(cast)); } } @@ -22,8 +22,7 @@ class AlgSimp : public BasicStmtVisitor { static constexpr int max_weaken_exponent = 32; using BasicStmtVisitor::visit; bool fast_math; - std::vector to_erase; - std::vector, Stmt *>> to_insert_before; + DelayedIRModifier modifier; explicit AlgSimp(bool fast_math_) : BasicStmtVisitor(), fast_math(fast_math_) { @@ -44,22 +43,22 @@ class AlgSimp : public BasicStmtVisitor { if (alg_is_zero(rhs)) { // a +-|^ 0 -> a stmt->replace_with(stmt->lhs); - to_erase.push_back(stmt); + modifier.erase(stmt); } else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs)) { // 0 +|^ a -> a stmt->replace_with(stmt->rhs); - to_erase.push_back(stmt); + modifier.erase(stmt); } } else if (stmt->op_type == BinaryOpType::mul || stmt->op_type == BinaryOpType::div) { if (alg_is_one(rhs)) { // a */ 1 -> a stmt->replace_with(stmt->lhs); - to_erase.push_back(stmt); + modifier.erase(stmt); } else if (stmt->op_type == BinaryOpType::mul && alg_is_one(lhs)) { // 1 * a -> a stmt->replace_with(stmt->rhs); - to_erase.push_back(stmt); + modifier.erase(stmt); } else if ((fast_math || is_integral(stmt->ret_type.data_type)) && stmt->op_type == BinaryOpType::mul && (alg_is_zero(lhs) || alg_is_zero(rhs))) { @@ -67,17 +66,17 @@ class AlgSimp : public BasicStmtVisitor { if (alg_is_zero(lhs) && lhs->ret_type.data_type == stmt->ret_type.data_type) { stmt->replace_with(stmt->lhs); - to_erase.push_back(stmt); + modifier.erase(stmt); } else if (alg_is_zero(rhs) && rhs->ret_type.data_type == stmt->ret_type.data_type) { stmt->replace_with(stmt->rhs); - to_erase.push_back(stmt); + modifier.erase(stmt); } else { auto zero = Stmt::make( LaneAttribute(stmt->ret_type.data_type)); stmt->replace_with(zero.get()); - to_insert_before.emplace_back(std::move(zero), stmt); - to_erase.push_back(stmt); + modifier.insert_before(stmt, std::move(zero)); + modifier.erase(stmt); } } else if (stmt->op_type == BinaryOpType::mul && (alg_is_two(lhs) || alg_is_two(rhs))) { @@ -89,8 +88,8 @@ class AlgSimp : public BasicStmtVisitor { auto sum = Stmt::make(BinaryOpType::add, a, a); sum->ret_type.data_type = a->ret_type.data_type; stmt->replace_with(sum.get()); - to_insert_before.emplace_back(std::move(sum), stmt); - to_erase.push_back(stmt); + modifier.insert_before(stmt, std::move(sum)); + modifier.erase(stmt); } else if (fast_math && stmt->op_type == BinaryOpType::div && rhs && is_real(rhs->ret_type.data_type)) { if (alg_is_zero(rhs)) { @@ -112,24 +111,24 @@ class AlgSimp : public BasicStmtVisitor { reciprocal.get()); product->ret_type.data_type = stmt->ret_type.data_type; stmt->replace_with(product.get()); - to_insert_before.emplace_back(std::move(reciprocal), stmt); - to_insert_before.emplace_back(std::move(product), stmt); - to_erase.push_back(stmt); + modifier.insert_before(stmt, std::move(reciprocal)); + modifier.insert_before(stmt, std::move(product)); + modifier.erase(stmt); } } } else if (stmt->op_type == BinaryOpType::pow) { if (alg_is_one(rhs)) { // a ** 1 -> a stmt->replace_with(stmt->lhs); - to_erase.push_back(stmt); + modifier.erase(stmt); } else if (alg_is_zero(rhs)) { // a ** 0 -> 1 auto one = Stmt::make(LaneAttribute(1)); auto one_raw = one.get(); - to_insert_before.emplace_back(std::move(one), stmt); + modifier.insert_before(stmt, std::move(one)); cast_to_result_type(one_raw, stmt); stmt->replace_with(one_raw); - to_erase.push_back(stmt); + modifier.erase(stmt); } else if (alg_is_two(rhs)) { // a ** 2.0 -> a * a auto a = stmt->lhs; @@ -137,8 +136,8 @@ class AlgSimp : public BasicStmtVisitor { auto product = Stmt::make(BinaryOpType::mul, a, a); product->ret_type.data_type = a->ret_type.data_type; stmt->replace_with(product.get()); - to_insert_before.emplace_back(std::move(product), stmt); - to_erase.push_back(stmt); + modifier.insert_before(stmt, std::move(product)); + modifier.erase(stmt); } else if (rhs && is_integral(rhs->ret_type.data_type) && ((is_signed(rhs->ret_type.data_type) && rhs->val[0].val_int() >= 0 && @@ -163,7 +162,7 @@ class AlgSimp : public BasicStmtVisitor { result, a_power_of_2); new_result->ret_type.data_type = a->ret_type.data_type; result = new_result.get(); - to_insert_before.emplace_back(std::move(new_result), stmt); + modifier.insert_before(stmt, std::move(new_result)); } } current_exponent <<= 1; @@ -173,10 +172,10 @@ class AlgSimp : public BasicStmtVisitor { BinaryOpType::mul, a_power_of_2, a_power_of_2); new_a_power->ret_type.data_type = a->ret_type.data_type; a_power_of_2 = new_a_power.get(); - to_insert_before.emplace_back(std::move(new_a_power), stmt); + modifier.insert_before(stmt, std::move(new_a_power)); } stmt->replace_with(result); - to_erase.push_back(stmt); + modifier.erase(stmt); } else if (rhs && is_integral(rhs->ret_type.data_type) && is_signed(rhs->ret_type.data_type) && rhs->val[0].val_int() < 0 && @@ -184,7 +183,7 @@ class AlgSimp : public BasicStmtVisitor { // a ** -n -> 1 / a ** n auto one = Stmt::make(LaneAttribute(1)); auto one_raw = one.get(); - to_insert_before.emplace_back(std::move(one), stmt); + modifier.insert_before(stmt, std::move(one)); cast_to_result_type(one_raw, stmt); auto exponent = Stmt::make(LaneAttribute(-rhs->val[0])); @@ -194,20 +193,20 @@ class AlgSimp : public BasicStmtVisitor { auto result = Stmt::make(BinaryOpType::div, one_raw, a_to_n.get()); stmt->replace_with(result.get()); - to_insert_before.emplace_back(std::move(exponent), stmt); - to_insert_before.emplace_back(std::move(a_to_n), stmt); - to_insert_before.emplace_back(std::move(result), stmt); - to_erase.push_back(stmt); + modifier.insert_before(stmt, std::move(exponent)); + modifier.insert_before(stmt, std::move(a_to_n)); + modifier.insert_before(stmt, std::move(result)); + modifier.erase(stmt); } } else if (stmt->op_type == BinaryOpType::bit_and) { if (alg_is_minus_one(rhs)) { // a & -1 -> a stmt->replace_with(stmt->lhs); - to_erase.push_back(stmt); + modifier.erase(stmt); } else if (alg_is_minus_one(lhs)) { // -1 & a -> a stmt->replace_with(stmt->rhs); - to_erase.push_back(stmt); + modifier.erase(stmt); } } } @@ -218,7 +217,7 @@ class AlgSimp : public BasicStmtVisitor { return; if (!alg_is_zero(cond)) { // this statement has no effect - to_erase.push_back(stmt); + modifier.erase(stmt); } } @@ -228,7 +227,7 @@ class AlgSimp : public BasicStmtVisitor { return; if (!alg_is_zero(cond)) { // this statement has no effect - to_erase.push_back(stmt); + modifier.erase(stmt); } } @@ -309,17 +308,10 @@ class AlgSimp : public BasicStmtVisitor { bool modified = false; while (true) { node->accept(&simplifier); - if (simplifier.to_erase.empty() && simplifier.to_insert_before.empty()) + if (simplifier.modifier.modify_ir()) + modified = true; + else break; - modified = true; - for (auto &i : simplifier.to_insert_before) { - i.second->insert_before_me(std::move(i.first)); - } - for (auto &stmt : simplifier.to_erase) { - stmt->parent->erase(stmt); - } - simplifier.to_insert_before.clear(); - simplifier.to_erase.clear(); } return modified; }