diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index d89150f3dcc3c..7dc5b60a33106 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -14,20 +14,18 @@ namespace { // Offloaded local variables to its offset in the global tmps memory. using StmtToOffsetMap = std::unordered_map; -std::unique_ptr> begin_stmt, - end_stmt; +struct OffloadedRanges { + using Map = std::unordered_map; + Map begin_stmts; + Map end_stmts; +}; // Break kernel into multiple parts and emit struct for listgens class Offloader { public: - Offloader(IRNode *root) { - begin_stmt = - std::make_unique>(); - end_stmt = std::make_unique>(); - run(root); - } + static OffloadedRanges run(IRNode *root) { + OffloadedRanges offloaded_ranges; - void run(IRNode *root) { auto root_block = dynamic_cast(root); auto root_statements = std::move(root_block->statements); root_block->statements.clear(); @@ -54,13 +52,15 @@ class Offloader { offloaded->const_begin = true; offloaded->begin_value = val->val[0].val_int32(); } else { - begin_stmt->insert(std::make_pair(offloaded.get(), s->begin)); + offloaded_ranges.begin_stmts.insert( + std::make_pair(offloaded.get(), s->begin)); } if (auto val = s->end->cast()) { offloaded->const_end = true; offloaded->end_value = val->val[0].val_int32(); } else { - end_stmt->insert(std::make_pair(offloaded.get(), s->end)); + offloaded_ranges.end_stmts.insert( + std::make_pair(offloaded.get(), s->end)); } offloaded->block_dim = s->block_dim; offloaded->num_cpu_threads = s->parallelize; @@ -77,9 +77,11 @@ class Offloader { } } assemble_serial_statements(); + return offloaded_ranges; } - void emit_struct_for(StructForStmt *for_stmt, Block *root_block) { + private: + static void emit_struct_for(StructForStmt *for_stmt, Block *root_block) { auto leaf = for_stmt->snode; // make a list of nodes, from the leaf block (instead of 'place') to root std::vector path; @@ -186,8 +188,10 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor { private: IdentifyValuesUsedInOtherOffloads( - const std::unordered_map &stmt_to_offloaded) - : stmt_to_offloaded(stmt_to_offloaded) { + const std::unordered_map &stmt_to_offloaded, + OffloadedRanges *offloaded_ranges) + : stmt_to_offloaded(stmt_to_offloaded), + offloaded_ranges_(offloaded_ranges) { allow_undefined_visitor = true; invoke_default_visitor = true; current_offloaded = nullptr; @@ -205,10 +209,12 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor { public: void visit(OffloadedStmt *stmt) override { current_offloaded = stmt; - if (auto begin = begin_stmt->find(stmt); begin != begin_stmt->end()) { + if (auto begin = offloaded_ranges_->begin_stmts.find(stmt); + begin != offloaded_ranges_->begin_stmts.end()) { test_and_allocate(begin->second); } - if (auto end = end_stmt->find(stmt); end != end_stmt->end()) { + if (auto end = offloaded_ranges_->end_stmts.find(stmt); + end != offloaded_ranges_->end_stmts.end()) { test_and_allocate(end->second); } if (stmt->body) @@ -267,16 +273,18 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor { static StmtToOffsetMap run( IRNode *root, - const std::unordered_map &stmt_to_offloaded) { - IdentifyValuesUsedInOtherOffloads pass(stmt_to_offloaded); + const std::unordered_map &stmt_to_offloaded, + OffloadedRanges *offloaded_ranges) { + IdentifyValuesUsedInOtherOffloads pass(stmt_to_offloaded, offloaded_ranges); root->accept(&pass); return pass.local_to_global; } private: + std::unordered_map stmt_to_offloaded; + OffloadedRanges *const offloaded_ranges_; // Local variables to global temporary offsets (in bytes) StmtToOffsetMap local_to_global; - std::unordered_map stmt_to_offloaded; Stmt *current_offloaded; std::size_t global_offset; }; @@ -331,9 +339,11 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { private: FixCrossOffloadReferences( const StmtToOffsetMap &local_to_global_offset, - std::unordered_map stmt_to_offloaded) + const std::unordered_map &stmt_to_offloaded, + OffloadedRanges *offloaded_ranges) : local_to_global_offset(local_to_global_offset), - stmt_to_offloaded(std::move(stmt_to_offloaded)) { + stmt_to_offloaded(stmt_to_offloaded), + offloaded_ranges_(offloaded_ranges) { allow_undefined_visitor = true; invoke_default_visitor = true; } @@ -344,9 +354,12 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { if (stmt->task_type == OffloadedStmt::TaskType::range_for) { if (!stmt->const_begin) stmt->begin_offset = - local_to_global_offset[begin_stmt->find(stmt)->second]; + local_to_global_offset[offloaded_ranges_->begin_stmts.find(stmt) + ->second]; if (!stmt->const_end) - stmt->end_offset = local_to_global_offset[end_stmt->find(stmt)->second]; + stmt->end_offset = + local_to_global_offset[offloaded_ranges_->end_stmts.find(stmt) + ->second]; } } @@ -480,9 +493,11 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { public: static void run(IRNode *root, - std::unordered_map stmt_to_offloaded, - const StmtToOffsetMap &local_to_global_offset) { - FixCrossOffloadReferences pass(local_to_global_offset, stmt_to_offloaded); + const StmtToOffsetMap &local_to_global_offset, + const std::unordered_map &stmt_to_offloaded, + OffloadedRanges *offloaded_ranges) { + FixCrossOffloadReferences pass(local_to_global_offset, stmt_to_offloaded, + offloaded_ranges); while (true) { try { root->accept(&pass); @@ -495,8 +510,9 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { private: StmtToOffsetMap local_to_global_offset; - std::unordered_map local_to_global_vector_type; std::unordered_map stmt_to_offloaded; + OffloadedRanges *const offloaded_ranges_; + std::unordered_map local_to_global_vector_type; }; void insert_gc(IRNode *root) { @@ -590,17 +606,17 @@ class AssociateContinueScope : public BasicStmtVisitor { } // namespace void offload(IRNode *root) { - Offloader _(root); + auto offloaded_ranges = Offloader::run(root); typecheck(root); fix_block_parents(root); { auto stmt_to_offloaded = StmtToOffloaded::run(root); - const auto local_to_global_offset = - IdentifyValuesUsedInOtherOffloads::run(root, stmt_to_offloaded); + const auto local_to_global_offset = IdentifyValuesUsedInOtherOffloads::run( + root, stmt_to_offloaded, &offloaded_ranges); PromoteIntermediateToGlobalTmp::run(root, local_to_global_offset); stmt_to_offloaded = StmtToOffloaded::run(root); - FixCrossOffloadReferences::run(root, stmt_to_offloaded, - local_to_global_offset); + FixCrossOffloadReferences::run(root, local_to_global_offset, + stmt_to_offloaded, &offloaded_ranges); fix_block_parents(root); } insert_gc(root);