From b1bee9f86e894b85bc3d039cbd5010282eb3158d Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Sat, 3 Apr 2021 20:26:02 +0800 Subject: [PATCH] [ir] Add RAII guards to IR Builder (#2242) * [ir] Add RAII guards to IR Builder * fix CE and add an optimization * Apply suggestions from code review Co-authored-by: Yuanming Hu * Add get_loop_guard and rename XxxStmt * Add signed conversion * code format * Apply review * Apply review Co-authored-by: Yuanming Hu --- taichi/ir/ir_builder.cpp | 39 +++++++++++++++ taichi/ir/ir_builder.h | 71 ++++++++++++++++++++++------ tests/cpp_new/ir/ir_builder_test.cpp | 37 +++++++++++++-- 3 files changed, 128 insertions(+), 19 deletions(-) diff --git a/taichi/ir/ir_builder.cpp b/taichi/ir/ir_builder.cpp index 25ef6188555a7..45f0a17e16d77 100644 --- a/taichi/ir/ir_builder.cpp +++ b/taichi/ir/ir_builder.cpp @@ -4,6 +4,15 @@ TLANG_NAMESPACE_BEGIN +namespace { + +inline bool stmt_location_did_not_change(Stmt *stmt, int location) { + return location >= 0 && location < stmt->parent->size() && + stmt->parent->statements[location].get() == stmt; +} + +} // namespace + IRBuilder::IRBuilder() { reset(); } @@ -44,6 +53,36 @@ void IRBuilder::set_insertion_point_to_false_branch(IfStmt *if_stmt) { set_insertion_point({if_stmt->false_statements.get(), 0}); } +IRBuilder::LoopGuard::~LoopGuard() { + if (stmt_location_did_not_change(loop_, location_)) { + // faster than set_insertion_point_to_after() + builder_.set_insertion_point({loop_->parent, location_ + 1}); + } else { + builder_.set_insertion_point_to_after(loop_); + } +} + +IRBuilder::IfGuard::IfGuard(IRBuilder &builder, + IfStmt *if_stmt, + bool true_branch) + : builder_(builder), if_stmt_(if_stmt) { + location_ = (int)if_stmt_->parent->size() - 1; + if (true_branch) { + builder_.set_insertion_point_to_true_branch(if_stmt_); + } else { + builder_.set_insertion_point_to_false_branch(if_stmt_); + } +} + +IRBuilder::IfGuard::~IfGuard() { + if (stmt_location_did_not_change(if_stmt_, location_)) { + // faster than set_insertion_point_to_after() + builder_.set_insertion_point({if_stmt_->parent, location_ + 1}); + } else { + builder_.set_insertion_point_to_after(if_stmt_); + } +} + RangeForStmt *IRBuilder::create_range_for(Stmt *begin, Stmt *end, int vectorize, diff --git a/taichi/ir/ir_builder.h b/taichi/ir/ir_builder.h index 22ea0e50ca8ed..53581b87cbfd5 100644 --- a/taichi/ir/ir_builder.h +++ b/taichi/ir/ir_builder.h @@ -25,18 +25,18 @@ class IRBuilder { std::unique_ptr extract_ir(); // General inserter. Returns stmt.get(). - template - XxxStmt *insert(std::unique_ptr &&stmt) { + template + XStmt *insert(std::unique_ptr &&stmt) { return insert(std::move(stmt), &insert_point_); } // Insert to a specific insertion point. - template - static XxxStmt *insert(std::unique_ptr &&stmt, - InsertPoint *insert_point) { + template + static XStmt *insert(std::unique_ptr &&stmt, + InsertPoint *insert_point) { return insert_point->block ->insert(std::move(stmt), insert_point->position++) - ->template as(); + ->template as(); } void set_insertion_point(InsertPoint new_insert_point); @@ -44,9 +44,9 @@ class IRBuilder { void set_insertion_point_to_before(Stmt *stmt); void set_insertion_point_to_true_branch(IfStmt *if_stmt); void set_insertion_point_to_false_branch(IfStmt *if_stmt); - template - void set_insertion_point_to_loop_begin(XxxStmt *loop) { - using DecayedType = typename std::decay_t; + template + void set_insertion_point_to_loop_begin(XStmt *loop) { + using DecayedType = typename std::decay_t; if constexpr (!std::is_base_of_v) { TI_ERROR("The argument is not a statement."); } @@ -59,6 +59,47 @@ class IRBuilder { } } + // RAII handles insertion points automatically. + class LoopGuard { + public: + // Set the insertion point to the beginning of the loop body. + template + explicit LoopGuard(IRBuilder &builder, XStmt *loop) + : builder_(builder), loop_(loop) { + location_ = (int)loop->parent->size() - 1; + builder_.set_insertion_point_to_loop_begin(loop); + } + + // Set the insertion point to the point after the loop. + ~LoopGuard(); + + private: + IRBuilder &builder_; + Stmt *loop_; + int location_; + }; + class IfGuard { + public: + // Set the insertion point to the beginning of the true/false branch. + explicit IfGuard(IRBuilder &builder, IfStmt *if_stmt, bool true_branch); + + // Set the insertion point to the point after the if statement. + ~IfGuard(); + + private: + IRBuilder &builder_; + IfStmt *if_stmt_; + int location_; + }; + + template + LoopGuard get_loop_guard(XStmt *loop) { + return LoopGuard(*this, loop); + } + IfGuard get_if_guard(IfStmt *if_stmt, bool true_branch) { + return IfGuard(*this, if_stmt, true_branch); + } + // Control flows. RangeForStmt *create_range_for(Stmt *begin, Stmt *end, @@ -164,9 +205,9 @@ class IRBuilder { const std::vector &indices); ExternalPtrStmt *create_external_ptr(ArgLoadStmt *ptr, const std::vector &indices); - template - GlobalLoadStmt *create_global_load(XxxStmt *ptr) { - using DecayedType = typename std::decay_t; + template + GlobalLoadStmt *create_global_load(XStmt *ptr) { + using DecayedType = typename std::decay_t; if constexpr (!std::is_base_of_v) { TI_ERROR("The argument is not a statement."); } @@ -177,9 +218,9 @@ class IRBuilder { TI_ERROR("Statement {} is not a global pointer.", ptr->name()); } } - template - void create_global_store(XxxStmt *ptr, Stmt *data) { - using DecayedType = typename std::decay_t; + template + void create_global_store(XStmt *ptr, Stmt *data) { + using DecayedType = typename std::decay_t; if constexpr (!std::is_base_of_v) { TI_ERROR("The argument is not a statement."); } diff --git a/tests/cpp_new/ir/ir_builder_test.cpp b/tests/cpp_new/ir/ir_builder_test.cpp index 543859df99794..2202c58de9b68 100644 --- a/tests/cpp_new/ir/ir_builder_test.cpp +++ b/tests/cpp_new/ir/ir_builder_test.cpp @@ -44,10 +44,12 @@ TEST(IRBuilder, RangeFor) { auto *zero = builder.get_int32(0); auto *ten = builder.get_int32(10); auto *loop = builder.create_range_for(zero, ten); - builder.set_insertion_point_to_loop_begin(loop); - auto *index = builder.get_loop_index(loop, 0); - builder.set_insertion_point_to_after(loop); - auto *ret = builder.create_return(zero); + Stmt *index; + { + auto _ = builder.get_loop_guard(loop); + index = builder.get_loop_index(loop, 0); + } + [[maybe_unused]] auto *ret = builder.create_return(zero); EXPECT_EQ(zero->parent->size(), 4); ASSERT_TRUE(loop->is()); auto *loopc = loop->cast(); @@ -55,6 +57,33 @@ TEST(IRBuilder, RangeFor) { EXPECT_EQ(loopc->body->statements[0].get(), index); } +TEST(IRBuilder, LoopGuard) { + IRBuilder builder; + auto *zero = builder.get_int32(0); + auto *ten = builder.get_int32(10); + auto *loop = builder.create_range_for(zero, ten); + Stmt *two; + Stmt *one; + Stmt *sum; + { + auto _ = builder.get_loop_guard(loop); + one = builder.get_int32(1); + builder.set_insertion_point_to_before(loop); + two = builder.get_int32(2); + builder.set_insertion_point_to_after(one); + sum = builder.create_add(one, two); + } + // The insertion point should be after the loop now. + auto *print = builder.create_print(two); + EXPECT_EQ(zero->parent->size(), 5); + EXPECT_EQ(zero->parent->statements[2].get(), two); + EXPECT_EQ(zero->parent->statements[3].get(), loop); + EXPECT_EQ(zero->parent->statements[4].get(), print); + EXPECT_EQ(loop->body->size(), 2); + EXPECT_EQ(loop->body->statements[0].get(), one); + EXPECT_EQ(loop->body->statements[1].get(), sum); +} + TEST(IRBuilder, ExternalPtr) { auto prog = Program(arch_from_name("x64")); prog.materialize_layout();