Skip to content

Commit

Permalink
[ir] Add RAII guards to IR Builder (#2242)
Browse files Browse the repository at this point in the history
* [ir] Add RAII guards to IR Builder

* fix CE and add an optimization

* Apply suggestions from code review

Co-authored-by: Yuanming Hu <yuanming-hu@users.noreply.github.com>

* Add get_loop_guard and rename XxxStmt

* Add signed conversion

* code format

* Apply review

* Apply review

Co-authored-by: Yuanming Hu <yuanming-hu@users.noreply.github.com>
  • Loading branch information
xumingkuan and yuanming-hu authored Apr 3, 2021
1 parent c74fc2f commit b1bee9f
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 19 deletions.
39 changes: 39 additions & 0 deletions taichi/ir/ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@

TLANG_NAMESPACE_BEGIN

namespace {

inline bool stmt_location_did_not_change(Stmt *stmt, int location) {
return location >= 0 && location < stmt->parent->size() &&
stmt->parent->statements[location].get() == stmt;
}

} // namespace

IRBuilder::IRBuilder() {
reset();
}
Expand Down Expand Up @@ -44,6 +53,36 @@ void IRBuilder::set_insertion_point_to_false_branch(IfStmt *if_stmt) {
set_insertion_point({if_stmt->false_statements.get(), 0});
}

IRBuilder::LoopGuard::~LoopGuard() {
if (stmt_location_did_not_change(loop_, location_)) {
// faster than set_insertion_point_to_after()
builder_.set_insertion_point({loop_->parent, location_ + 1});
} else {
builder_.set_insertion_point_to_after(loop_);
}
}

IRBuilder::IfGuard::IfGuard(IRBuilder &builder,
IfStmt *if_stmt,
bool true_branch)
: builder_(builder), if_stmt_(if_stmt) {
location_ = (int)if_stmt_->parent->size() - 1;
if (true_branch) {
builder_.set_insertion_point_to_true_branch(if_stmt_);
} else {
builder_.set_insertion_point_to_false_branch(if_stmt_);
}
}

IRBuilder::IfGuard::~IfGuard() {
if (stmt_location_did_not_change(if_stmt_, location_)) {
// faster than set_insertion_point_to_after()
builder_.set_insertion_point({if_stmt_->parent, location_ + 1});
} else {
builder_.set_insertion_point_to_after(if_stmt_);
}
}

RangeForStmt *IRBuilder::create_range_for(Stmt *begin,
Stmt *end,
int vectorize,
Expand Down
71 changes: 56 additions & 15 deletions taichi/ir/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,28 @@ class IRBuilder {
std::unique_ptr<IRNode> extract_ir();

// General inserter. Returns stmt.get().
template <typename XxxStmt>
XxxStmt *insert(std::unique_ptr<XxxStmt> &&stmt) {
template <typename XStmt>
XStmt *insert(std::unique_ptr<XStmt> &&stmt) {
return insert(std::move(stmt), &insert_point_);
}

// Insert to a specific insertion point.
template <typename XxxStmt>
static XxxStmt *insert(std::unique_ptr<XxxStmt> &&stmt,
InsertPoint *insert_point) {
template <typename XStmt>
static XStmt *insert(std::unique_ptr<XStmt> &&stmt,
InsertPoint *insert_point) {
return insert_point->block
->insert(std::move(stmt), insert_point->position++)
->template as<XxxStmt>();
->template as<XStmt>();
}

void set_insertion_point(InsertPoint new_insert_point);
void set_insertion_point_to_after(Stmt *stmt);
void set_insertion_point_to_before(Stmt *stmt);
void set_insertion_point_to_true_branch(IfStmt *if_stmt);
void set_insertion_point_to_false_branch(IfStmt *if_stmt);
template <typename XxxStmt>
void set_insertion_point_to_loop_begin(XxxStmt *loop) {
using DecayedType = typename std::decay_t<XxxStmt>;
template <typename XStmt>
void set_insertion_point_to_loop_begin(XStmt *loop) {
using DecayedType = typename std::decay_t<XStmt>;
if constexpr (!std::is_base_of_v<Stmt, DecayedType>) {
TI_ERROR("The argument is not a statement.");
}
Expand All @@ -59,6 +59,47 @@ class IRBuilder {
}
}

// RAII handles insertion points automatically.
class LoopGuard {
public:
// Set the insertion point to the beginning of the loop body.
template <typename XStmt>
explicit LoopGuard(IRBuilder &builder, XStmt *loop)
: builder_(builder), loop_(loop) {
location_ = (int)loop->parent->size() - 1;
builder_.set_insertion_point_to_loop_begin(loop);
}

// Set the insertion point to the point after the loop.
~LoopGuard();

private:
IRBuilder &builder_;
Stmt *loop_;
int location_;
};
class IfGuard {
public:
// Set the insertion point to the beginning of the true/false branch.
explicit IfGuard(IRBuilder &builder, IfStmt *if_stmt, bool true_branch);

// Set the insertion point to the point after the if statement.
~IfGuard();

private:
IRBuilder &builder_;
IfStmt *if_stmt_;
int location_;
};

template <typename XStmt>
LoopGuard get_loop_guard(XStmt *loop) {
return LoopGuard(*this, loop);
}
IfGuard get_if_guard(IfStmt *if_stmt, bool true_branch) {
return IfGuard(*this, if_stmt, true_branch);
}

// Control flows.
RangeForStmt *create_range_for(Stmt *begin,
Stmt *end,
Expand Down Expand Up @@ -164,9 +205,9 @@ class IRBuilder {
const std::vector<Stmt *> &indices);
ExternalPtrStmt *create_external_ptr(ArgLoadStmt *ptr,
const std::vector<Stmt *> &indices);
template <typename XxxStmt>
GlobalLoadStmt *create_global_load(XxxStmt *ptr) {
using DecayedType = typename std::decay_t<XxxStmt>;
template <typename XStmt>
GlobalLoadStmt *create_global_load(XStmt *ptr) {
using DecayedType = typename std::decay_t<XStmt>;
if constexpr (!std::is_base_of_v<Stmt, DecayedType>) {
TI_ERROR("The argument is not a statement.");
}
Expand All @@ -177,9 +218,9 @@ class IRBuilder {
TI_ERROR("Statement {} is not a global pointer.", ptr->name());
}
}
template <typename XxxStmt>
void create_global_store(XxxStmt *ptr, Stmt *data) {
using DecayedType = typename std::decay_t<XxxStmt>;
template <typename XStmt>
void create_global_store(XStmt *ptr, Stmt *data) {
using DecayedType = typename std::decay_t<XStmt>;
if constexpr (!std::is_base_of_v<Stmt, DecayedType>) {
TI_ERROR("The argument is not a statement.");
}
Expand Down
37 changes: 33 additions & 4 deletions tests/cpp_new/ir/ir_builder_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,46 @@ TEST(IRBuilder, RangeFor) {
auto *zero = builder.get_int32(0);
auto *ten = builder.get_int32(10);
auto *loop = builder.create_range_for(zero, ten);
builder.set_insertion_point_to_loop_begin(loop);
auto *index = builder.get_loop_index(loop, 0);
builder.set_insertion_point_to_after(loop);
auto *ret = builder.create_return(zero);
Stmt *index;
{
auto _ = builder.get_loop_guard(loop);
index = builder.get_loop_index(loop, 0);
}
[[maybe_unused]] auto *ret = builder.create_return(zero);
EXPECT_EQ(zero->parent->size(), 4);
ASSERT_TRUE(loop->is<RangeForStmt>());
auto *loopc = loop->cast<RangeForStmt>();
EXPECT_EQ(loopc->body->size(), 1);
EXPECT_EQ(loopc->body->statements[0].get(), index);
}

TEST(IRBuilder, LoopGuard) {
IRBuilder builder;
auto *zero = builder.get_int32(0);
auto *ten = builder.get_int32(10);
auto *loop = builder.create_range_for(zero, ten);
Stmt *two;
Stmt *one;
Stmt *sum;
{
auto _ = builder.get_loop_guard(loop);
one = builder.get_int32(1);
builder.set_insertion_point_to_before(loop);
two = builder.get_int32(2);
builder.set_insertion_point_to_after(one);
sum = builder.create_add(one, two);
}
// The insertion point should be after the loop now.
auto *print = builder.create_print(two);
EXPECT_EQ(zero->parent->size(), 5);
EXPECT_EQ(zero->parent->statements[2].get(), two);
EXPECT_EQ(zero->parent->statements[3].get(), loop);
EXPECT_EQ(zero->parent->statements[4].get(), print);
EXPECT_EQ(loop->body->size(), 2);
EXPECT_EQ(loop->body->statements[0].get(), one);
EXPECT_EQ(loop->body->statements[1].get(), sum);
}

TEST(IRBuilder, ExternalPtr) {
auto prog = Program(arch_from_name("x64"));
prog.materialize_layout();
Expand Down

0 comments on commit b1bee9f

Please sign in to comment.