Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Remove the global begin_stmt and end_stmt #1034

Merged
merged 2 commits into from
May 23, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 48 additions & 32 deletions taichi/transforms/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,18 @@ namespace {
// Offloaded local variables to its offset in the global tmps memory.
using StmtToOffsetMap = std::unordered_map<const Stmt *, std::size_t>;

std::unique_ptr<std::unordered_map<OffloadedStmt *, Stmt *>> begin_stmt,
end_stmt;
struct OffloadedRanges {
using Map = std::unordered_map<const OffloadedStmt *, Stmt *>;
Map begin_stmts;
Map end_stmts;
};

// Break kernel into multiple parts and emit struct for listgens
class Offloader {
public:
Offloader(IRNode *root) {
begin_stmt =
std::make_unique<std::unordered_map<OffloadedStmt *, Stmt *>>();
end_stmt = std::make_unique<std::unordered_map<OffloadedStmt *, Stmt *>>();
run(root);
}
static OffloadedRanges run(IRNode *root) {
OffloadedRanges offloaded_ranges;

void run(IRNode *root) {
auto root_block = dynamic_cast<Block *>(root);
auto root_statements = std::move(root_block->statements);
root_block->statements.clear();
Expand All @@ -54,13 +52,15 @@ class Offloader {
offloaded->const_begin = true;
offloaded->begin_value = val->val[0].val_int32();
} else {
begin_stmt->insert(std::make_pair(offloaded.get(), s->begin));
offloaded_ranges.begin_stmts.insert(
std::make_pair(offloaded.get(), s->begin));
}
if (auto val = s->end->cast<ConstStmt>()) {
offloaded->const_end = true;
offloaded->end_value = val->val[0].val_int32();
} else {
end_stmt->insert(std::make_pair(offloaded.get(), s->end));
offloaded_ranges.end_stmts.insert(
std::make_pair(offloaded.get(), s->end));
}
offloaded->block_dim = s->block_dim;
offloaded->num_cpu_threads = s->parallelize;
Expand All @@ -77,9 +77,11 @@ class Offloader {
}
}
assemble_serial_statements();
return offloaded_ranges;
}

void emit_struct_for(StructForStmt *for_stmt, Block *root_block) {
private:
static void emit_struct_for(StructForStmt *for_stmt, Block *root_block) {
auto leaf = for_stmt->snode;
// make a list of nodes, from the leaf block (instead of 'place') to root
std::vector<SNode *> path;
Expand Down Expand Up @@ -186,8 +188,10 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor {

private:
IdentifyValuesUsedInOtherOffloads(
const std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded)
: stmt_to_offloaded(stmt_to_offloaded) {
const std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded,
OffloadedRanges *offloaded_ranges)
: stmt_to_offloaded(stmt_to_offloaded),
offloaded_ranges_(offloaded_ranges) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
current_offloaded = nullptr;
Expand All @@ -205,10 +209,12 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor {
public:
void visit(OffloadedStmt *stmt) override {
current_offloaded = stmt;
if (auto begin = begin_stmt->find(stmt); begin != begin_stmt->end()) {
if (auto begin = offloaded_ranges_->begin_stmts.find(stmt);
begin != offloaded_ranges_->begin_stmts.end()) {
test_and_allocate(begin->second);
}
if (auto end = end_stmt->find(stmt); end != end_stmt->end()) {
if (auto end = offloaded_ranges_->end_stmts.find(stmt);
end != offloaded_ranges_->end_stmts.end()) {
test_and_allocate(end->second);
}
if (stmt->body)
Expand Down Expand Up @@ -267,16 +273,18 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor {

static StmtToOffsetMap run(
IRNode *root,
const std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded) {
IdentifyValuesUsedInOtherOffloads pass(stmt_to_offloaded);
const std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded,
OffloadedRanges *offloaded_ranges) {
IdentifyValuesUsedInOtherOffloads pass(stmt_to_offloaded, offloaded_ranges);
root->accept(&pass);
return pass.local_to_global;
}

private:
std::unordered_map<Stmt *, Stmt *> stmt_to_offloaded;
OffloadedRanges *const offloaded_ranges_;
// Local variables to global temporary offsets (in bytes)
StmtToOffsetMap local_to_global;
std::unordered_map<Stmt *, Stmt *> stmt_to_offloaded;
Stmt *current_offloaded;
std::size_t global_offset;
};
Expand Down Expand Up @@ -331,9 +339,11 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
private:
FixCrossOffloadReferences(
const StmtToOffsetMap &local_to_global_offset,
std::unordered_map<Stmt *, Stmt *> stmt_to_offloaded)
const std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded,
Comment on lines -334 to +342
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why const & here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. If you look at the entire call stack, it was like:

offload()
\- FixCrossOffloadReferences::run()
    \- FixCrossOffloadReferences::FixCrossOffloadReferences()

Before this PR, only the last call used move, while the first two didn't (likely because we forgot to), so we still made two unnecessary copies. While move is a modern feature, IMO the easiest way to avoid copy remains using const &. (move also means transferring ownership, which is really useful when data cannot be copy constructed, e.g. std::unique_ptr). Does this make sense?

OffloadedRanges *offloaded_ranges)
: local_to_global_offset(local_to_global_offset),
stmt_to_offloaded(std::move(stmt_to_offloaded)) {
stmt_to_offloaded(stmt_to_offloaded),
offloaded_ranges_(offloaded_ranges) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}
Expand All @@ -344,9 +354,12 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
if (stmt->task_type == OffloadedStmt::TaskType::range_for) {
if (!stmt->const_begin)
stmt->begin_offset =
local_to_global_offset[begin_stmt->find(stmt)->second];
local_to_global_offset[offloaded_ranges_->begin_stmts.find(stmt)
->second];
if (!stmt->const_end)
stmt->end_offset = local_to_global_offset[end_stmt->find(stmt)->second];
stmt->end_offset =
local_to_global_offset[offloaded_ranges_->end_stmts.find(stmt)
->second];
}
}

Expand Down Expand Up @@ -480,9 +493,11 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {

public:
static void run(IRNode *root,
std::unordered_map<Stmt *, Stmt *> stmt_to_offloaded,
const StmtToOffsetMap &local_to_global_offset) {
FixCrossOffloadReferences pass(local_to_global_offset, stmt_to_offloaded);
const StmtToOffsetMap &local_to_global_offset,
const std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded,
OffloadedRanges *offloaded_ranges) {
FixCrossOffloadReferences pass(local_to_global_offset, stmt_to_offloaded,
offloaded_ranges);
while (true) {
try {
root->accept(&pass);
Expand All @@ -495,8 +510,9 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {

private:
StmtToOffsetMap local_to_global_offset;
std::unordered_map<Stmt *, VectorType> local_to_global_vector_type;
std::unordered_map<Stmt *, Stmt *> stmt_to_offloaded;
OffloadedRanges *const offloaded_ranges_;
std::unordered_map<Stmt *, VectorType> local_to_global_vector_type;
};

void insert_gc(IRNode *root) {
Expand Down Expand Up @@ -590,17 +606,17 @@ class AssociateContinueScope : public BasicStmtVisitor {
} // namespace

void offload(IRNode *root) {
Offloader _(root);
auto offloaded_ranges = Offloader::run(root);
typecheck(root);
fix_block_parents(root);
{
auto stmt_to_offloaded = StmtToOffloaded::run(root);
const auto local_to_global_offset =
IdentifyValuesUsedInOtherOffloads::run(root, stmt_to_offloaded);
const auto local_to_global_offset = IdentifyValuesUsedInOtherOffloads::run(
root, stmt_to_offloaded, &offloaded_ranges);
PromoteIntermediateToGlobalTmp::run(root, local_to_global_offset);
stmt_to_offloaded = StmtToOffloaded::run(root);
FixCrossOffloadReferences::run(root, stmt_to_offloaded,
local_to_global_offset);
FixCrossOffloadReferences::run(root, local_to_global_offset,
stmt_to_offloaded, &offloaded_ranges);
fix_block_parents(root);
}
insert_gc(root);
Expand Down