From 3ee55d7aa308bfd2edeb29bea330f2dee6c311eb Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Fri, 27 Mar 2020 18:48:24 -0400 Subject: [PATCH] [Opt] Merge adjacent if's with identical conditions (#668) * Merge adjacent if's with the identical condition * [skip ci] enforce code format * simplify using VecStatement * revert taichi/backend/metal/shaders reformatting Co-authored-by: Taichi Gardener Co-authored-by: Yuanming Hu --- taichi/ir/ir.h | 4 ++++ taichi/transforms/simplify.cpp | 24 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index f91d35686bfb3..ffffee2338e7a 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -306,6 +306,10 @@ class VecStatement { stmts = std::move(o.stmts); } + VecStatement(std::vector &&other_stmts) { + stmts = std::move(other_stmts); + } + Stmt *push_back(pStmt &&stmt) { auto ret = stmt.get(); stmts.push_back(std::move(stmt)); diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index b99960f36bec1..63ab43afd70ae 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -1186,6 +1186,30 @@ class BasicBlockSimplify : public IRVisitor { if_stmt->parent->erase(if_stmt); throw IRModified(); } + + if (advanced_optimization) { + // Merge adjacent if's with the identical condition. + // TODO: What about IfStmt::true_mask and IfStmt::false_mask? + if (current_stmt_id > 0 && + block->statements[current_stmt_id - 1]->is()) { + auto bstmt = block->statements[current_stmt_id - 1]->as(); + if (bstmt->cond == if_stmt->cond) { + auto concatenate = [](std::unique_ptr &clause1, + std::unique_ptr &clause2) { + if (clause1 == nullptr) { + clause1 = std::move(clause2); + return; + } + if (clause2 != nullptr) + clause1->insert(VecStatement(std::move(clause2->statements))); + }; + concatenate(bstmt->true_statements, if_stmt->true_statements); + concatenate(bstmt->false_statements, if_stmt->false_statements); + if_stmt->parent->erase(if_stmt); + throw IRModified(); + } + } + } } void visit(RangeAssumptionStmt *stmt) override {