diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index f91d35686bfb3..ffffee2338e7a 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -306,6 +306,10 @@ class VecStatement { stmts = std::move(o.stmts); } + VecStatement(std::vector &&other_stmts) { + stmts = std::move(other_stmts); + } + Stmt *push_back(pStmt &&stmt) { auto ret = stmt.get(); stmts.push_back(std::move(stmt)); diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index b99960f36bec1..63ab43afd70ae 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -1186,6 +1186,30 @@ class BasicBlockSimplify : public IRVisitor { if_stmt->parent->erase(if_stmt); throw IRModified(); } + + if (advanced_optimization) { + // Merge adjacent if's with the identical condition. + // TODO: What about IfStmt::true_mask and IfStmt::false_mask? + if (current_stmt_id > 0 && + block->statements[current_stmt_id - 1]->is()) { + auto bstmt = block->statements[current_stmt_id - 1]->as(); + if (bstmt->cond == if_stmt->cond) { + auto concatenate = [](std::unique_ptr &clause1, + std::unique_ptr &clause2) { + if (clause1 == nullptr) { + clause1 = std::move(clause2); + return; + } + if (clause2 != nullptr) + clause1->insert(VecStatement(std::move(clause2->statements))); + }; + concatenate(bstmt->true_statements, if_stmt->true_statements); + concatenate(bstmt->false_statements, if_stmt->false_statements); + if_stmt->parent->erase(if_stmt); + throw IRModified(); + } + } + } } void visit(RangeAssumptionStmt *stmt) override {