diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index a1b60f62dffa0..10c5152a99f59 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -42,12 +42,13 @@ std::unique_ptr clone_offloaded_task(OffloadedStmt *from, KernelLaunchRecord::KernelLaunchRecord(Context context, Kernel *kernel, - std::unique_ptr &&stmt_) + std::unique_ptr &&stmt_, + uint64 h) : context(context), kernel(kernel), stmt(dynamic_cast(stmt_.get())), - stmt_holder(std::move(stmt_)), - h(hash(stmt)) { + h(h), + stmt_holder_(std::move(stmt_)) { TI_ASSERT(stmt != nullptr); TI_ASSERT(stmt->get_kernel() != nullptr); } @@ -130,17 +131,27 @@ void AsyncEngine::launch(Kernel *kernel) { kernel->lower(/*to_executable=*/false); auto block = dynamic_cast(kernel->ir.get()); TI_ASSERT(block); + auto &offloads = block->statements; - auto &dummy_root = kernel_to_dummy_roots_[kernel]; - if (dummy_root == nullptr) { - dummy_root = std::make_unique(); - dummy_root->kernel = kernel; + auto &kmeta = kernel_metas_[kernel]; + const bool kmeta_inited = kmeta.initialized(); + if (!kmeta_inited) { + kmeta.dummy_root = std::make_unique(); + kmeta.dummy_root->kernel = kernel; } for (std::size_t i = 0; i < offloads.size(); i++) { auto offload = offloads[i]->as(); - KernelLaunchRecord rec( - kernel->program.get_context(), kernel, - clone_offloaded_task(offload, kernel, dummy_root.get())); + auto cloned = clone_offloaded_task(offload, kernel, kmeta.dummy_root.get()); + uint64 h; + if (kmeta_inited) { + h = kmeta.offloaded_hashes[i]; + } else { + h = hash(cloned.get()); + TI_ASSERT(kmeta.offloaded_hashes.size() == i); + kmeta.offloaded_hashes.push_back(h); + } + KernelLaunchRecord rec(kernel->program.get_context(), kernel, + std::move(cloned), h); enqueue(std::move(rec)); } } @@ -148,7 +159,7 @@ void AsyncEngine::launch(Kernel *kernel) { void AsyncEngine::enqueue(KernelLaunchRecord &&t) { using namespace irpass::analysis; - auto &meta = metas[t.h]; + auto &meta = offloaded_metas_[t.h]; // TODO: this is an abuse since it gathers nothing... auto root_stmt = t.stmt; gather_statements(root_stmt, [&](Stmt *stmt) { @@ -213,7 +224,7 @@ bool AsyncEngine::optimize_listgen() { for (int i = 0; i < task_queue.size(); i++) { // Try to eliminate unused listgens auto &t = task_queue[i]; - auto meta = metas[t.h]; + auto meta = offloaded_metas_[t.h]; auto offload = t.stmt; bool keep = true; if (offload->task_type == OffloadedStmt::TaskType::listgen) { diff --git a/taichi/program/async_engine.h b/taichi/program/async_engine.h index 35e5660355a83..00666c6f1605d 100644 --- a/taichi/program/async_engine.h +++ b/taichi/program/async_engine.h @@ -110,12 +110,15 @@ class KernelLaunchRecord { Context context; Kernel *kernel; // TODO: remove this OffloadedStmt *stmt; - std::unique_ptr stmt_holder; - uint64 h; + uint64 h; // hash of |stmt| - KernelLaunchRecord(Context contxet, + KernelLaunchRecord(Context context, Kernel *kernel, - std::unique_ptr &&stmt); + std::unique_ptr &&stmt, + uint64 h); + + private: + std::unique_ptr stmt_holder_; }; // In charge of (parallel) compilation to binary and (serial) kernel launching @@ -154,13 +157,6 @@ class AsyncEngine { public: // TODO: state machine - struct TaskMeta { - std::unordered_set input_snodes, output_snodes; - std::unordered_set activation_snodes; - }; - - std::unordered_map metas; - ExecutionQueue queue; std::deque task_queue; @@ -183,11 +179,26 @@ class AsyncEngine { void synchronize(); private: + struct KernelMeta { + std::unique_ptr dummy_root; + std::vector offloaded_hashes; + + inline bool initialized() const { + return dummy_root != nullptr; + } + }; + + struct TaskMeta { + std::unordered_set input_snodes, output_snodes; + std::unordered_set activation_snodes; + }; + // In async mode, the root of an AST is an OffloadedStmt instead of a Block. // This map provides a dummy Block root for these OffloadedStmt, so that // get_kernel() could still work correctly. - std::unordered_map> - kernel_to_dummy_roots_; + std::unordered_map kernel_metas_; + + std::unordered_map offloaded_metas_; }; TLANG_NAMESPACE_END