diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 8fe7e6518f554..1b033660ef6a1 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -2595,15 +2595,9 @@ bool CodeGenLLVM::maybe_read_compilation_from_cache( if (!reader->get_kernel_cache(cache_data, kernel_key, llvm_ctx)) { return false; } - this->module = std::move(cache_data.owned_module); - for (auto &task : cache_data.offloaded_task_list) { - auto &t = this->offloaded_tasks.emplace_back(task.name); - t.block_dim = task.block_dim; - t.grid_dim = task.grid_dim; - } + data->tasks = std::move(cache_data.compiled_data_list[0].tasks); + data->module = std::move(cache_data.compiled_data_list[0].module); kernel->set_from_offline_cache(); - data->tasks = std::move(this->offloaded_tasks); - data->module = std::move(this->module); return true; } @@ -2679,10 +2673,14 @@ void CodeGenLLVM::visit(FuncCallStmt *stmt) { } void CodeGenLLVM::cache_module(const std::string &kernel_key) { - std::vector offloaded_task_list = offloaded_tasks; - get_llvm_program(prog)->cache_kernel(kernel_key, this->module.get(), - infer_launch_args(kernel), - std::move(offloaded_task_list)); + std::vector data; + data.emplace_back(offloaded_tasks, llvm::CloneModule(*module)); + get_llvm_program(prog)->cache_kernel(kernel_key, data, + infer_launch_args(kernel)); +} + +LLVMCompiledData LLVMCompiledData::clone() const { + return {tasks, llvm::CloneModule(*module)}; } TLANG_NAMESPACE_END diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h index cb8241f60a51a..3b2328e7b594e 100644 --- a/taichi/codegen/llvm/codegen_llvm.h +++ b/taichi/codegen/llvm/codegen_llvm.h @@ -51,6 +51,7 @@ struct LLVMCompiledData { std::unique_ptr module) : tasks(std::move(tasks)), module(std::move(module)) { } + LLVMCompiledData clone() const; TI_IO_DEF(tasks); }; diff --git a/taichi/runtime/cpu/aot_module_loader_impl.cpp b/taichi/runtime/cpu/aot_module_loader_impl.cpp index 092eff1040098..8345d373a07ec 100644 --- a/taichi/runtime/cpu/aot_module_loader_impl.cpp +++ b/taichi/runtime/cpu/aot_module_loader_impl.cpp @@ -24,10 +24,8 @@ class AotModuleImpl : public LlvmAotModule { auto *tlctx = executor_->get_llvm_context(arch); CPUModuleToFunctionConverter converter{tlctx, executor_}; - std::vector data; - data.emplace_back(std::move(loaded.offloaded_task_list), - std::move(loaded.owned_module)); - return converter.convert(name, loaded.args, std::move(data)); + return converter.convert(name, loaded.args, + std::move(loaded.compiled_data_list)); } std::unique_ptr make_new_kernel_template( diff --git a/taichi/runtime/cuda/aot_module_loader_impl.cpp b/taichi/runtime/cuda/aot_module_loader_impl.cpp index 65690522d4741..42dd4bc961fbc 100644 --- a/taichi/runtime/cuda/aot_module_loader_impl.cpp +++ b/taichi/runtime/cuda/aot_module_loader_impl.cpp @@ -24,11 +24,8 @@ class AotModuleImpl : public LlvmAotModule { auto *tlctx = executor_->get_llvm_context(arch); CUDAModuleToFunctionConverter converter{tlctx, executor_}; - - std::vector data; - data.emplace_back(std::move(loaded.offloaded_task_list), - std::move(loaded.owned_module)); - return converter.convert(name, loaded.args, std::move(data)); + return converter.convert(name, loaded.args, + std::move(loaded.compiled_data_list)); } std::unique_ptr make_new_kernel_template( diff --git a/taichi/runtime/llvm/llvm_aot_module_builder.cpp b/taichi/runtime/llvm/llvm_aot_module_builder.cpp index 76de10854efeb..9860a93247c7e 100644 --- a/taichi/runtime/llvm/llvm_aot_module_builder.cpp +++ b/taichi/runtime/llvm/llvm_aot_module_builder.cpp @@ -22,10 +22,8 @@ void LlvmAotModuleBuilder::add_per_backend(const std::string &identifier, auto compiled = compile_kernel(kernel); LlvmOfflineCache::KernelCacheData kcache; kcache.kernel_key = identifier; - kcache.module = compiled.module.get(); - kcache.owned_module = std::move(compiled.module); + kcache.compiled_data_list.push_back(std::move(compiled)); kcache.args = infer_launch_args(kernel); - kcache.offloaded_task_list = std::move(compiled.tasks); kcache.last_used_at = std::time(nullptr); kcache.created_at = std::time(nullptr); cache_.kernels[identifier] = std::move(kcache); diff --git a/taichi/runtime/llvm/llvm_offline_cache.cpp b/taichi/runtime/llvm/llvm_offline_cache.cpp index 8f4a017b0c62f..19d12c22f7b97 100644 --- a/taichi/runtime/llvm/llvm_offline_cache.cpp +++ b/taichi/runtime/llvm/llvm_offline_cache.cpp @@ -108,11 +108,16 @@ bool LlvmOfflineCacheFileReader::get_kernel_cache( } auto &kernel_data = itr->second; - if (kernel_data.owned_module == nullptr) { - const std::string filename_prefix = taichi::join_path(path_, key); - kernel_data.owned_module = load_module(filename_prefix, key, llvm_ctx); - TI_ASSERT(kernel_data.owned_module != nullptr); - kernel_data.module = kernel_data.owned_module.get(); + for (int i = 0; i < kernel_data.compiled_data_list.size(); i++) { + auto &data = kernel_data.compiled_data_list[i]; + if (!data.module) { + std::string filename_prefix = + taichi::join_path(path_, key + "." + std::to_string(i)); + data.module = load_module(filename_prefix, key, llvm_ctx); + TI_ASSERT(data.module); + } + res.compiled_data_list.emplace_back(data.tasks, + llvm::CloneModule(*data.module)); } kernel_data.last_used_at = std::time(nullptr); @@ -121,9 +126,6 @@ bool LlvmOfflineCacheFileReader::get_kernel_cache( res.last_used_at = kernel_data.last_used_at; res.kernel_key = key; res.args = kernel_data.args; - res.offloaded_task_list = kernel_data.offloaded_task_list; - res.owned_module = llvm::CloneModule(*kernel_data.module); - res.module = res.owned_module.get(); return true; } @@ -168,22 +170,24 @@ void LlvmOfflineCacheFileWriter::dump(const std::string &path, return llvm_os.tell(); }; { - auto *mod = v.module; - if (!mod) { - mod = v.owned_module.get(); - } - TI_ASSERT(mod != nullptr); - - mangle_offloaded_task_name(k, mod, v.offloaded_task_list); - if (format & Format::LL) { - size += write_llvm_module(".ll", [mod](llvm::raw_os_ostream &os) { - mod->print(os, /*AAW=*/nullptr); - }); - } - if (format & Format::BC) { - size += write_llvm_module(".bc", [mod](llvm::raw_os_ostream &os) { - llvm::WriteBitcodeToFile(*mod, os); - }); + mangle_offloaded_task_name(k, v.compiled_data_list); + for (int i = 0; i < v.compiled_data_list.size(); i++) { + auto &data = v.compiled_data_list[i]; + auto *mod = data.module.get(); + TI_ASSERT(mod != nullptr); + std::string suffix = "." + std::to_string(i); + if (format & Format::LL) { + size += write_llvm_module(suffix + ".ll", + [mod](llvm::raw_os_ostream &os) { + mod->print(os, /*AAW=*/nullptr); + }); + } + if (format & Format::BC) { + size += write_llvm_module(suffix + ".bc", + [mod](llvm::raw_os_ostream &os) { + llvm::WriteBitcodeToFile(*mod, os); + }); + } } } @@ -240,16 +244,18 @@ void LlvmOfflineCacheFileWriter::merge_with(LlvmOfflineCache &&data) { void LlvmOfflineCacheFileWriter::mangle_offloaded_task_name( const std::string &kernel_key, - llvm::Module *module, - std::vector &offloaded_task_list) { + std::vector &compiled_data_list) { if (!mangled_) { std::size_t cnt = 0; - for (auto &e : offloaded_task_list) { - std::string mangled_name = kernel_key + std::to_string(cnt++); - auto func = module->getFunction(e.name); - TI_ASSERT(func != nullptr); - func->setName(mangled_name); - e.name = mangled_name; + for (auto &e : compiled_data_list) { + for (auto &offload : e.tasks) { + std::string mangled_name = kernel_key + std::to_string(cnt++); + + auto func = e.module->getFunction(offload.name); + TI_ASSERT(func != nullptr); + func->setName(mangled_name); + offload.name = mangled_name; + } } } } @@ -314,9 +320,11 @@ void LlvmOfflineCacheFileWriter::clean_cache(const std::string &path, } TI_ASSERT(q.size() <= cnt); while (!q.empty()) { - for (const auto &f : - get_possible_llvm_cache_filename_by_key(q.top().kernel_key)) { - taichi::remove(taichi::join_path(path, f)); + for (int i = 0; i < q.top().compiled_data_list.size(); i++) { + for (const auto &f : get_possible_llvm_cache_filename_by_key( + q.top().kernel_key + "." + std::to_string(i))) { + taichi::remove(taichi::join_path(path, f)); + } } q.pop(); } @@ -343,5 +351,13 @@ LlvmOfflineCacheFileWriter::string_to_clean_cache_policy( return Never; } +LlvmOfflineCache::KernelCacheData LlvmOfflineCache::KernelCacheData::clone() + const { + std::vector new_data_list; + for (const auto &data : compiled_data_list) { + new_data_list.push_back(data.clone()); + } + return {kernel_key, args, std::move(new_data_list)}; +} } // namespace lang } // namespace taichi diff --git a/taichi/runtime/llvm/llvm_offline_cache.h b/taichi/runtime/llvm/llvm_offline_cache.h index b2c02bc734452..8535fbb3abd73 100644 --- a/taichi/runtime/llvm/llvm_offline_cache.h +++ b/taichi/runtime/llvm/llvm_offline_cache.h @@ -23,10 +23,7 @@ struct LlvmOfflineCache { struct KernelCacheData { std::string kernel_key; std::vector args; - std::vector offloaded_task_list; - - std::unique_ptr owned_module{nullptr}; - llvm::Module *module{nullptr}; + std::vector compiled_data_list; // For cache cleaning std::size_t size{0}; // byte @@ -38,9 +35,11 @@ struct LlvmOfflineCache { KernelCacheData &operator=(KernelCacheData &&) = default; ~KernelCacheData() = default; + KernelCacheData clone() const; + TI_IO_DEF(kernel_key, args, - offloaded_task_list, + compiled_data_list, size, created_at, last_used_at); @@ -185,8 +184,7 @@ class LlvmOfflineCacheFileWriter { void mangle_offloaded_task_name( const std::string &kernel_key, - llvm::Module *module, - std::vector &offloaded_task_list); + std::vector &compiled_data_list); LlvmOfflineCache data_; bool mangled_{false}; diff --git a/taichi/runtime/program_impls/llvm/llvm_program.cpp b/taichi/runtime/program_impls/llvm/llvm_program.cpp index c4740854a273a..16c765d417c6a 100644 --- a/taichi/runtime/program_impls/llvm/llvm_program.cpp +++ b/taichi/runtime/program_impls/llvm/llvm_program.cpp @@ -107,32 +107,28 @@ std::unique_ptr LlvmProgramImpl::make_aot_kernel(Kernel &kernel) { TI_ASSERT(cache_data_->kernels.count(kernel_key)); const LlvmOfflineCache::KernelCacheData &kernel_data = cache_data_->kernels[kernel_key]; - - LlvmOfflineCache::KernelCacheData compiled_kernel; + LlvmOfflineCache::KernelCacheData compiled_kernel = kernel_data.clone(); compiled_kernel.kernel_key = kernel.get_name(); - compiled_kernel.owned_module = - llvm::CloneModule(*kernel_data.owned_module.get()); - compiled_kernel.args = kernel_data.args; - compiled_kernel.offloaded_task_list = kernel_data.offloaded_task_list; return std::make_unique(compiled_fn, kernel.get_name(), std::move(compiled_kernel)); } void LlvmProgramImpl::cache_kernel( const std::string &kernel_key, - llvm::Module *module, - std::vector &&args, - std::vector &&offloaded_task_list) { + const std::vector &data_list, + std::vector &&args) { if (cache_data_->kernels.find(kernel_key) != cache_data_->kernels.end()) { return; } auto &kernel_cache = cache_data_->kernels[kernel_key]; - kernel_cache.created_at = std::time(nullptr); - kernel_cache.last_used_at = std::time(nullptr); kernel_cache.kernel_key = kernel_key; - kernel_cache.owned_module = llvm::CloneModule(*module); + for (const auto &data : data_list) { + kernel_cache.compiled_data_list.emplace_back( + data.tasks, llvm::CloneModule(*data.module)); + } kernel_cache.args = std::move(args); - kernel_cache.offloaded_task_list = std::move(offloaded_task_list); + kernel_cache.created_at = std::time(nullptr); + kernel_cache.last_used_at = std::time(nullptr); } void LlvmProgramImpl::cache_field(int snode_tree_id, diff --git a/taichi/runtime/program_impls/llvm/llvm_program.h b/taichi/runtime/program_impls/llvm/llvm_program.h index aeb7ff436516c..9b9be65511bc5 100644 --- a/taichi/runtime/program_impls/llvm/llvm_program.h +++ b/taichi/runtime/program_impls/llvm/llvm_program.h @@ -51,9 +51,9 @@ class LlvmProgramImpl : public ProgramImpl { void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override; void cache_kernel(const std::string &kernel_key, - llvm::Module *module, - std::vector &&args, - std::vector &&offloaded_task_list); + const std::vector &data_list, + std::vector &&args); + ; void cache_field(int snode_tree_id, int root_id, diff --git a/tests/cpp/llvm/llvm_offline_cache_test.cpp b/tests/cpp/llvm/llvm_offline_cache_test.cpp index 0e2e69937962d..bb669f9d51ae9 100644 --- a/tests/cpp/llvm/llvm_offline_cache_test.cpp +++ b/tests/cpp/llvm/llvm_offline_cache_test.cpp @@ -99,9 +99,13 @@ TEST_P(LlvmOfflineCacheTest, ReadWrite) { kcache.created_at = 1; kcache.last_used_at = 1; kcache.kernel_key = kKernelName; - kcache.owned_module = make_module(*llvm_ctx); - kcache.module = kcache.owned_module.get(); - kcache.offloaded_task_list.emplace_back(kTaskName, kBlockDim, kGridDim); + std::vector tasks; + OffloadedTask task; + task.name = kTaskName; + task.block_dim = kBlockDim; + task.grid_dim = kGridDim; + tasks.push_back(task); + kcache.compiled_data_list.emplace_back(tasks, make_module(*llvm_ctx)); kcache.args = arg_infos; writer.add_kernel_cache(kKernelName, std::move(kcache)); writer.set_no_mangle(); @@ -115,13 +119,13 @@ TEST_P(LlvmOfflineCacheTest, ReadWrite) { const bool ok = reader->get_kernel_cache(kcache, kKernelName, *llvm_ctx); ASSERT_TRUE(ok); EXPECT_EQ(kcache.kernel_key, kKernelName); - EXPECT_EQ(kcache.offloaded_task_list.size(), 1); - const auto &task0 = kcache.offloaded_task_list.front(); + EXPECT_EQ(kcache.compiled_data_list[0].tasks.size(), 1); + const auto &task0 = kcache.compiled_data_list[0].tasks.front(); EXPECT_EQ(task0.name, kTaskName); - ASSERT_NE(kcache.owned_module, nullptr); - kcache.module->dump(); - tlctx_->add_module(std::move(kcache.owned_module)); + ASSERT_NE(kcache.compiled_data_list[0].module, nullptr); + kcache.compiled_data_list[0].module->dump(); + tlctx_->add_module(std::move(kcache.compiled_data_list[0].module)); using FuncType = int (*)(int, int); FuncType my_add = (FuncType)tlctx_->lookup_function_pointer(kTaskName); const auto res = my_add(40, 2);