diff --git a/taichi/aot/module_builder.cpp b/taichi/aot/module_builder.cpp index d5f7668a0a430..1b1b71fb58c8f 100644 --- a/taichi/aot/module_builder.cpp +++ b/taichi/aot/module_builder.cpp @@ -6,7 +6,7 @@ namespace lang { void AotModuleBuilder::add(const std::string &identifier, Kernel *kernel) { if (!kernel->lowered() && Kernel::supports_lowering(kernel->arch)) { - kernel->lower(); + kernel->lower(/*to_executable=*/!arch_uses_llvm(kernel->arch)); } add_per_backend(identifier, kernel); } diff --git a/taichi/codegen/codegen.cpp b/taichi/codegen/codegen.cpp index 8faa9bec9d974..181157334115a 100644 --- a/taichi/codegen/codegen.cpp +++ b/taichi/codegen/codegen.cpp @@ -87,6 +87,8 @@ void KernelCodeGen::cache_module(const std::string &kernel_key, } std::vector KernelCodeGen::compile_kernel_to_module() { + auto *llvm_prog = get_llvm_program(prog); + auto *tlctx = llvm_prog->get_llvm_context(kernel->arch); auto &config = prog->config; std::string kernel_key = get_hashed_offline_cache_key(&config, kernel); kernel->set_kernel_key_for_cache(kernel_key); @@ -98,6 +100,7 @@ std::vector KernelCodeGen::compile_kernel_to_module() { TI_DEBUG("Create kernel '{}' from cache (key='{}')", kernel->get_name(), kernel_key); cache_module(kernel_key, res); + TI_ASSERT(res.size() == 1); return res; } } @@ -110,17 +113,17 @@ std::vector KernelCodeGen::compile_kernel_to_module() { TI_ASSERT(block); auto &offloads = block->statements; - std::vector data(offloads.size()); + std::vector> data(offloads.size()); using TaskFunc = int32 (*)(void *); std::vector task_funcs(offloads.size()); for (int i = 0; i < offloads.size(); i++) { auto compile_func = [&, i] { + tlctx->fetch_this_thread_struct_module(); auto offload = irpass::analysis::clone(offloads[i].get(), offloads[i]->get_kernel()); irpass::re_id(offload.get()); auto new_data = this->compile_task(nullptr, offload->as()); - data[i].tasks = std::move(new_data.tasks); - data[i].module = std::move(new_data.module); + data[i] = std::make_unique(std::move(new_data)); }; if (kernel->is_evaluator) { compile_func(); @@ -131,11 +134,15 @@ std::vector KernelCodeGen::compile_kernel_to_module() { if (!kernel->is_evaluator) { worker.flush(); } + auto linked = tlctx->link_compile_data(std::move(data)); + std::vector linked_data; + linked_data.push_back(std::move(*linked)); + if (!kernel->is_evaluator) { TI_DEBUG("Cache kernel '{}' (key='{}')", kernel->get_name(), kernel_key); - cache_module(kernel_key, data); + cache_module(kernel_key, linked_data); } - return data; + return linked_data; } ModuleToFunctionConverter::ModuleToFunctionConverter( diff --git a/taichi/codegen/codegen.h b/taichi/codegen/codegen.h index f2c41ccb0870f..53233ac45d734 100644 --- a/taichi/codegen/codegen.h +++ b/taichi/codegen/codegen.h @@ -33,7 +33,7 @@ class KernelCodeGen { } #ifdef TI_WITH_LLVM - std::vector compile_kernel_to_module(); + virtual std::vector compile_kernel_to_module(); virtual LLVMCompiledData compile_task( std::unique_ptr &&module = nullptr, diff --git a/taichi/codegen/cpu/codegen_cpu.cpp b/taichi/codegen/cpu/codegen_cpu.cpp index e1d1b72143599..ca4742bf51e6a 100644 --- a/taichi/codegen/cpu/codegen_cpu.cpp +++ b/taichi/codegen/cpu/codegen_cpu.cpp @@ -235,16 +235,14 @@ FunctionType CPUModuleToFunctionConverter::convert( const std::string &kernel_name, const std::vector &args, std::vector &&data) const { - for (auto &datum : data) { - tlctx_->add_module(std::move(datum.module)); - } - + TI_AUTO_PROF; + auto jit_module = tlctx_->create_jit_module(std::move(data.back().module)); using TaskFunc = int32 (*)(void *); std::vector task_funcs; task_funcs.reserve(data.size()); for (auto &datum : data) { for (auto &task : datum.tasks) { - auto *func_ptr = tlctx_->lookup_function_pointer(task.name); + auto *func_ptr = jit_module->lookup_function(task.name); TI_ASSERT_INFO(func_ptr, "Offloaded datum function {} not found", task.name); task_funcs.push_back((TaskFunc)(func_ptr)); diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index eb98dfa502444..afc91af6da0bb 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -715,25 +715,20 @@ FunctionType CUDAModuleToFunctionConverter::convert( const std::string &kernel_name, const std::vector &args, std::vector &&data) const { + auto &mod = data[0].module; + auto &tasks = data[0].tasks; #ifdef TI_WITH_CUDA - std::vector cuda_modules; - std::vector> offloaded_tasks; - cuda_modules.reserve(data.size()); - for (auto &datum : data) { - auto &mod = datum.module; - auto &tasks = datum.tasks; - for (const auto &task : tasks) { - llvm::Function *func = mod->getFunction(task.name); - TI_ASSERT(func); - tlctx_->mark_function_as_cuda_kernel(func, task.block_dim); - } - auto jit = tlctx_->jit.get(); - cuda_modules.push_back( - jit->add_module(std::move(mod), executor_->get_config()->gpu_max_reg)); - offloaded_tasks.push_back(std::move(tasks)); + for (const auto &task : tasks) { + llvm::Function *func = mod->getFunction(task.name); + TI_ASSERT(func); + tlctx_->mark_function_as_cuda_kernel(func, task.block_dim); } - return [cuda_modules, kernel_name, args, offloaded_tasks, + auto jit = tlctx_->jit.get(); + auto cuda_module = jit->add_module( + std::move(mod), executor_->get_config()->gpu_max_reg); + + return [cuda_module, kernel_name, args, offloaded_tasks = tasks, executor = this->executor_](RuntimeContext &context) { CUDAContext::get_instance().make_current(); std::vector arg_buffers(args.size(), nullptr); @@ -797,13 +792,11 @@ FunctionType CUDAModuleToFunctionConverter::convert( CUDADriver::get_instance().stream_synchronize(nullptr); } - for (int i = 0; i < offloaded_tasks.size(); i++) { - for (auto &task : offloaded_tasks[i]) { - TI_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim, - task.block_dim); - cuda_modules[i]->launch(task.name, task.grid_dim, task.block_dim, 0, - {&context}); - } + for (auto task : offloaded_tasks) { + TI_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim, + task.block_dim); + cuda_module->launch(task.name, task.grid_dim, task.block_dim, 0, + {&context}); } // copy data back to host diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 09de38810bfd8..9ab43d0d1722b 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -318,11 +318,13 @@ void TaskCodeGenLLVM::emit_struct_meta_base(const std::string &name, // snodes, even if they have the same type. if (snode->parent) common.set("from_parent_element", - get_runtime_function(snode->get_ch_from_parent_func_name())); + get_struct_function(snode->get_ch_from_parent_func_name(), + snode->get_snode_tree_id())); if (snode->type != SNodeType::place) common.set("refine_coordinates", - get_runtime_function(snode->refine_coordinates_func_name())); + get_struct_function(snode->refine_coordinates_func_name(), + snode->get_snode_tree_id())); } TaskCodeGenLLVM::TaskCodeGenLLVM(Kernel *kernel, @@ -332,7 +334,7 @@ TaskCodeGenLLVM::TaskCodeGenLLVM(Kernel *kernel, : LLVMModuleBuilder( module == nullptr ? get_llvm_program(kernel->program) ->get_llvm_context(kernel->arch) - ->clone_struct_module() + ->new_module("kernel") : std::move(module), get_llvm_program(kernel->program)->get_llvm_context(kernel->arch)), kernel(kernel), @@ -1706,10 +1708,11 @@ void TaskCodeGenLLVM::visit(GetChStmt *stmt) { auto offset = tlctx->get_constant(bit_offset); llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_ptr], offset); } else { - auto ch = create_call(stmt->output_snode->get_ch_from_parent_func_name(), - {builder->CreateBitCast( - llvm_val[stmt->input_ptr], - llvm::PointerType::getInt8PtrTy(*llvm_context))}); + auto ch = call_struct_func( + stmt->output_snode->get_snode_tree_id(), + stmt->output_snode->get_ch_from_parent_func_name(), + builder->CreateBitCast(llvm_val[stmt->input_ptr], + llvm::PointerType::getInt8PtrTy(*llvm_context))); llvm_val[stmt] = builder->CreateBitCast( ch, llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type( module.get(), stmt->output_snode), @@ -1989,7 +1992,8 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, create_entry_block_alloca(physical_coordinate_ty); auto refine = - get_runtime_function(leaf_block->refine_coordinates_func_name()); + get_struct_function(leaf_block->refine_coordinates_func_name(), + leaf_block->get_snode_tree_id()); // A block corner is the global coordinate/index of the lower-left corner // cell within that block, and is the same for all the cells within that // block. @@ -2068,8 +2072,8 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, // needed to make final coordinates non-consecutive, since each thread will // process multiple coordinates via vectorization if (stmt->is_bit_vectorized) { - refine = - get_runtime_function(stmt->snode->refine_coordinates_func_name()); + refine = get_struct_function(stmt->snode->refine_coordinates_func_name(), + stmt->snode->get_snode_tree_id()); create_call(refine, {new_coordinates, new_coordinates, tlctx->get_constant(0)}); } @@ -2156,7 +2160,6 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, auto struct_for_func = get_runtime_function("parallel_struct_for"); if (arch_is_gpu(current_arch())) { - tlctx->add_struct_for_func(module.get(), stmt->tls_size); struct_for_func = llvm::cast( module ->getOrInsertFunction( diff --git a/taichi/codegen/llvm/llvm_codegen_utils.h b/taichi/codegen/llvm/llvm_codegen_utils.h index c21c0295e24d9..ec22ed9cdde14 100644 --- a/taichi/codegen/llvm/llvm_codegen_utils.h +++ b/taichi/codegen/llvm/llvm_codegen_utils.h @@ -87,35 +87,19 @@ class LLVMModuleBuilder { } llvm::Type *get_runtime_type(const std::string &name) { -#ifdef TI_LLVM_15 - auto ty = llvm::StructType::getTypeByName(module->getContext(), - ("struct." + name)); -#else - auto ty = module->getTypeByName("struct." + name); -#endif - if (!ty) { - TI_ERROR("LLVMRuntime type {} not found.", name); - } - return ty; + return tlctx->get_runtime_type(name); } llvm::Function *get_runtime_function(const std::string &name) { - auto f = module->getFunction(name); + auto f = tlctx->get_runtime_function(name); if (!f) { TI_ERROR("LLVMRuntime function {} not found.", name); } -#ifdef TI_LLVM_15 - f->removeFnAttr(llvm::Attribute::OptimizeNone); - f->removeFnAttr(llvm::Attribute::NoInline); - f->addFnAttr(llvm::Attribute::AlwaysInline); -#else - f->removeAttribute(llvm::AttributeList::FunctionIndex, - llvm::Attribute::OptimizeNone); - f->removeAttribute(llvm::AttributeList::FunctionIndex, - llvm::Attribute::NoInline); - f->addAttribute(llvm::AttributeList::FunctionIndex, - llvm::Attribute::AlwaysInline); -#endif + f = llvm::cast( + module + ->getOrInsertFunction(name, f->getFunctionType(), + f->getAttributes()) + .getCallee()); return f; } diff --git a/taichi/codegen/llvm/llvm_compiled_data.h b/taichi/codegen/llvm/llvm_compiled_data.h index 0cf52571f1358..59f9a33c3bb42 100644 --- a/taichi/codegen/llvm/llvm_compiled_data.h +++ b/taichi/codegen/llvm/llvm_compiled_data.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include "llvm/IR/Module.h" diff --git a/taichi/codegen/llvm/struct_llvm.cpp b/taichi/codegen/llvm/struct_llvm.cpp index 907342858dfa8..62b986541e060 100644 --- a/taichi/codegen/llvm/struct_llvm.cpp +++ b/taichi/codegen/llvm/struct_llvm.cpp @@ -292,7 +292,7 @@ void StructCompilerLLVM::run(SNode &root) { auto node_type = get_llvm_node_type(module.get(), &root); root_size = tlctx_->get_type_size(node_type); - tlctx_->set_struct_module(module); + tlctx_->add_struct_module(std::move(module), root.get_snode_tree_id()); } llvm::Type *StructCompilerLLVM::get_stub(llvm::Module *module, @@ -336,7 +336,6 @@ llvm::Type *StructCompilerLLVM::get_llvm_element_type(llvm::Module *module, llvm::Function *StructCompilerLLVM::create_function(llvm::FunctionType *ft, std::string func_name) { - tlctx_->add_function_to_snode_tree(snode_tree_id_, func_name); return llvm::Function::Create(ft, llvm::Function::ExternalLinkage, func_name, *module); } diff --git a/taichi/codegen/wasm/codegen_wasm.cpp b/taichi/codegen/wasm/codegen_wasm.cpp index 250179d65d1bd..9f25e0bae202d 100644 --- a/taichi/codegen/wasm/codegen_wasm.cpp +++ b/taichi/codegen/wasm/codegen_wasm.cpp @@ -9,6 +9,7 @@ #include "taichi/ir/statements.h" #include "taichi/util/statistics.h" #include "taichi/util/file_sequence_writer.h" +#include "taichi/runtime/program_impls/llvm/llvm_program.h" namespace taichi { namespace lang { @@ -243,10 +244,10 @@ class TaskCodeGenWASM : public TaskCodeGenLLVM { FunctionType KernelCodeGenWASM::compile_to_function() { TI_AUTO_PROF - TaskCodeGenWASM gen(kernel, ir); - auto res = gen.run_compilation(); - gen.tlctx->add_module(std::move(res.module)); - auto kernel_symbol = gen.tlctx->lookup_function_pointer(res.tasks[0].name); + auto linked = std::move(compile_kernel_to_module()[0]); + auto *tlctx = get_llvm_program(prog)->get_llvm_context(kernel->arch); + tlctx->create_jit_module(std::move(linked.module)); + auto kernel_symbol = tlctx->lookup_function_pointer(linked.tasks[0].name); return [=](RuntimeContext &context) { TI_TRACE("Launching Taichi Kernel Function"); auto func = (int32(*)(void *))kernel_symbol; @@ -257,6 +258,7 @@ FunctionType KernelCodeGenWASM::compile_to_function() { LLVMCompiledData KernelCodeGenWASM::compile_task( std::unique_ptr &&module, OffloadedStmt *stmt) { + kernel->offload_to_executable(ir); bool init_flag = module == nullptr; std::vector name_list; auto gen = std::make_unique(kernel, ir, std::move(module)); @@ -278,5 +280,21 @@ LLVMCompiledData KernelCodeGenWASM::compile_task( return {name_list, std::move(gen->module), {}, {}}; } + + +std::vector KernelCodeGenWASM::compile_kernel_to_module() { + auto *tlctx = get_llvm_program(prog)->get_llvm_context(kernel->arch); + if (!kernel->lowered()) { + kernel->lower(/*to_executable=*/false); + } + auto res = compile_task(); + std::vector> data; + data.push_back(std::make_unique(std::move(res))); + auto linked = tlctx->link_compile_data(std::move(data)); + std::vector ret; + ret.push_back(std::move(*linked)); + return ret; +} + } // namespace lang } // namespace taichi diff --git a/taichi/codegen/wasm/codegen_wasm.h b/taichi/codegen/wasm/codegen_wasm.h index fde841012314f..91ec5aaf6a648 100644 --- a/taichi/codegen/wasm/codegen_wasm.h +++ b/taichi/codegen/wasm/codegen_wasm.h @@ -23,6 +23,8 @@ class KernelCodeGenWASM : public KernelCodeGen { LLVMCompiledData compile_task( std::unique_ptr &&module = nullptr, OffloadedStmt *stmt = nullptr) override; // AOT Module Gen + + std::vector compile_kernel_to_module() override; #endif }; diff --git a/taichi/runtime/cpu/aot_module_builder_impl.cpp b/taichi/runtime/cpu/aot_module_builder_impl.cpp index c5577cd43337b..daf89298f3d76 100644 --- a/taichi/runtime/cpu/aot_module_builder_impl.cpp +++ b/taichi/runtime/cpu/aot_module_builder_impl.cpp @@ -10,8 +10,8 @@ namespace lang { namespace cpu { LLVMCompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) { - auto cgen = KernelCodeGenCPU::make_codegen_llvm(kernel, /*ir=*/nullptr); - return cgen->run_compilation(); + auto cgen = KernelCodeGenCPU(kernel); + return std::move(cgen.compile_kernel_to_module()[0]); } } // namespace cpu diff --git a/taichi/runtime/cuda/aot_module_builder_impl.cpp b/taichi/runtime/cuda/aot_module_builder_impl.cpp index 2a40066f3471f..0d7431d9874ad 100644 --- a/taichi/runtime/cuda/aot_module_builder_impl.cpp +++ b/taichi/runtime/cuda/aot_module_builder_impl.cpp @@ -10,8 +10,8 @@ namespace lang { namespace cuda { LLVMCompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) { - auto cgen = KernelCodeGenCUDA::make_codegen_llvm(kernel, /*ir=*/nullptr); - return cgen->run_compilation(); + auto cgen = KernelCodeGenCUDA(kernel); + return std::move(cgen.compile_kernel_to_module()[0]); } } // namespace cuda diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index c1a1c52a3454d..dea1e5d638d6b 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -112,6 +112,13 @@ TaichiLLVMContext::TaichiLLVMContext(CompileConfig *config, Arch arch) #endif } jit = JITSession::create(this, config, arch); + + link_context_data = std::make_unique( + std::make_unique( + std::make_unique())); + link_context_data->runtime_module = clone_module_to_context( + get_this_thread_runtime_module(), link_context_data->llvm_context); + TI_TRACE("Taichi llvm context created."); } @@ -503,26 +510,18 @@ void TaichiLLVMContext::link_module_with_cuda_libdevice( TI_ERROR("CUDA libdevice linking failure."); } - // Make sure all libdevice functions are linked, and set their linkage to - // internal + // Make sure all libdevice functions are linked for (auto func_name : libdevice_function_names) { auto func = module->getFunction(func_name); if (!func) { TI_INFO("Function {} not found", func_name); - } else - func->setLinkage(llvm::Function::InternalLinkage); + } } } -std::unique_ptr TaichiLLVMContext::clone_struct_module() { - TI_AUTO_PROF - auto struct_module = get_this_thread_struct_module(); - TI_ASSERT(struct_module); - return llvm::CloneModule(*struct_module); -} - -void TaichiLLVMContext::set_struct_module( - const std::unique_ptr &module) { +void TaichiLLVMContext::add_struct_module(std::unique_ptr module, + int tree_id) { + TI_AUTO_PROF; TI_ASSERT(std::this_thread::get_id() == main_thread_id_); auto this_thread_data = get_this_thread_data(); TI_ASSERT(module); @@ -530,16 +529,19 @@ void TaichiLLVMContext::set_struct_module( module->print(llvm::errs(), nullptr); TI_ERROR("module broken"); } - // TODO: Move this after ``if (!arch_is_cpu(arch))``. - this_thread_data->struct_module = llvm::CloneModule(*module); + + link_context_data->struct_modules[tree_id] = + clone_module_to_context(module.get(), link_context_data->llvm_context); + for (auto &[id, data] : per_thread_data_) { if (id == std::this_thread::get_id()) { continue; } - TI_ASSERT(!data->runtime_module); - data->struct_module = clone_module_to_context( - this_thread_data->struct_module.get(), data->llvm_context); + data->struct_modules[tree_id] = + clone_module_to_context(module.get(), data->llvm_context); } + + this_thread_data->struct_modules[tree_id] = std::move(module); } template llvm::Value *TaichiLLVMContext::get_constant(DataType dt, T t) { @@ -661,7 +663,7 @@ llvm::DataLayout TaichiLLVMContext::get_data_layout() { return jit->get_data_layout(); } -JITModule *TaichiLLVMContext::add_module(std::unique_ptr module) { +JITModule *TaichiLLVMContext::create_jit_module(std::unique_ptr module) { return jit->add_module(std::move(module)); } @@ -755,15 +757,6 @@ TaichiLLVMContext::get_this_thread_thread_safe_context() { return data->thread_safe_llvm_context.get(); } -llvm::Module *TaichiLLVMContext::get_this_thread_struct_module() { - ThreadLocalData *data = get_this_thread_data(); - if (!data->struct_module) { - data->struct_module = clone_module_to_this_thread_context( - main_thread_data_->struct_module.get()); - } - return data->struct_module.get(); -} - template llvm::Value *TaichiLLVMContext::get_constant(float32 t); template llvm::Value *TaichiLLVMContext::get_constant(float64 t); @@ -823,24 +816,14 @@ void TaichiLLVMContext::update_runtime_jit_module( return starts_with(func_name, "runtime_") || starts_with(func_name, "LLVMRuntime_"); }); - runtime_jit_module = add_module(std::move(module)); + runtime_jit_module = create_jit_module(std::move(module)); } -void TaichiLLVMContext::delete_functions_of_snode_tree(int id) { - if (!snode_tree_funcs_.count(id)) { - return; +void TaichiLLVMContext::delete_snode_tree(int id) { + TI_ASSERT(link_context_data->struct_modules.erase(id)); + for (auto &[thread_id, data] : per_thread_data_) { + TI_ASSERT(data->struct_modules.erase(id)); } - llvm::Module *module = get_this_thread_struct_module(); - for (auto str : snode_tree_funcs_[id]) { - auto *func = module->getFunction(str); - func->eraseFromParent(); - } - snode_tree_funcs_.erase(id); - set_struct_module(get_this_thread_data()->struct_module); -} - -void TaichiLLVMContext::add_function_to_snode_tree(int id, std::string func) { - snode_tree_funcs_[id].push_back(func); } void TaichiLLVMContext::fetch_this_thread_struct_module() { @@ -906,6 +889,47 @@ TaichiLLVMContext::ThreadLocalData::~ThreadLocalData() { thread_safe_llvm_context.reset(); } +std::unique_ptr TaichiLLVMContext::link_compile_data( + std::vector> data_list) { + auto linked = std::make_unique(); + std::unordered_set used_tree_ids; + std::unordered_set tls_sizes; + std::unordered_set offloaded_names; + auto mod = new_module("kernel", link_context_data->llvm_context); + llvm::Linker linker(*mod); + for (auto &datum : data_list) { + for (auto tree_id : datum->used_tree_ids) { + used_tree_ids.insert(tree_id); + } + for (auto tls_size : datum->struct_for_tls_sizes) { + tls_sizes.insert(tls_size); + } + for (auto &task : datum->tasks) { + offloaded_names.insert(task.name); + linked->tasks.push_back(std::move(task)); + } + linker.linkInModule(clone_module_to_context( + datum->module.get(), link_context_data->llvm_context)); + } + for (auto tree_id : used_tree_ids) { + linker.linkInModule( + llvm::CloneModule(*link_context_data->struct_modules[tree_id]), + llvm::Linker::LinkOnlyNeeded | llvm::Linker::OverrideFromSrc); + } + auto runtime_module = llvm::CloneModule(*link_context_data->runtime_module); + for (auto tls_size : tls_sizes) { + add_struct_for_func(runtime_module.get(), tls_size); + } + linker.linkInModule( + std::move(runtime_module), + llvm::Linker::LinkOnlyNeeded | llvm::Linker::OverrideFromSrc); + eliminate_unused_functions(mod.get(), [&](std::string func_name) -> bool { + return offloaded_names.count(func_name); + }); + linked->module = std::move(mod); + return linked; +} + void TaichiLLVMContext::add_struct_for_func(llvm::Module *module, int tls_size) { // Note that on CUDA local array allocation must have a compile-time diff --git a/taichi/runtime/llvm/llvm_context.h b/taichi/runtime/llvm/llvm_context.h index 6eab588f704e9..4611168f71f0d 100644 --- a/taichi/runtime/llvm/llvm_context.h +++ b/taichi/runtime/llvm/llvm_context.h @@ -13,6 +13,7 @@ #include "taichi/runtime/llvm/llvm_fwd.h" #include "taichi/ir/snode.h" #include "taichi/jit/jit_session.h" +#include "taichi/codegen/llvm/llvm_compiled_data.h" namespace taichi { namespace lang { @@ -43,6 +44,8 @@ class TaichiLLVMContext { // main_thread is defined to be the thread that runs the initializer JITModule *runtime_jit_module{nullptr}; + std::unique_ptr link_context_data{nullptr}; + TaichiLLVMContext(CompileConfig *config, Arch arch); virtual ~TaichiLLVMContext(); @@ -60,19 +63,13 @@ class TaichiLLVMContext { */ void init_runtime_jit_module(); - /** - * Clones the LLVM module containing the JIT compiled SNode structs. - * - * @return The cloned module. - */ - std::unique_ptr clone_struct_module(); - /** * Updates the LLVM module of the JIT compiled SNode structs. * * @param module Module containing the JIT compiled SNode structs. */ - void set_struct_module(const std::unique_ptr &module); + void add_struct_module(std::unique_ptr module, + int tree_id); /** * Clones the LLVM module compiled from llvm/runtime.cpp @@ -83,7 +80,7 @@ class TaichiLLVMContext { std::unique_ptr module_from_file(const std::string &file); - JITModule *add_module(std::unique_ptr module); + JITModule *create_jit_module(std::unique_ptr module); virtual void *lookup_function_pointer(const std::string &name) { return jit->lookup(name); @@ -101,8 +98,6 @@ class TaichiLLVMContext { llvm::Type *get_data_type(DataType dt); - llvm::Module *get_this_thread_struct_module(); - template llvm::Type *get_data_type() { return TaichiLLVMContext::get_data_type(taichi::lang::get_data_type()); @@ -144,14 +139,15 @@ class TaichiLLVMContext { std::string name, llvm::LLVMContext *context = nullptr); - void add_function_to_snode_tree(int id, std::string func); - - void delete_functions_of_snode_tree(int id); + void delete_snode_tree(int id); void add_struct_for_func(llvm::Module *module, int tls_size); static std::string get_struct_for_func_name(int tls_size); + std::unique_ptr link_compile_data( + std::vector> data_list); + private: std::unique_ptr clone_module_to_context( llvm::Module *module, diff --git a/taichi/runtime/llvm/llvm_runtime_executor.cpp b/taichi/runtime/llvm/llvm_runtime_executor.cpp index 494c89d9bc850..576b31ea21646 100644 --- a/taichi/runtime/llvm/llvm_runtime_executor.cpp +++ b/taichi/runtime/llvm/llvm_runtime_executor.cpp @@ -613,8 +613,7 @@ void LlvmRuntimeExecutor::materialize_runtime(MemoryPool *memory_pool, } void LlvmRuntimeExecutor::destroy_snode_tree(SNodeTree *snode_tree) { - get_llvm_context(config_->arch) - ->delete_functions_of_snode_tree(snode_tree->id()); + get_llvm_context(config_->arch)->delete_snode_tree(snode_tree->id()); snode_tree_buffer_manager_->destroy(snode_tree); } diff --git a/taichi/runtime/program_impls/llvm/llvm_program.cpp b/taichi/runtime/program_impls/llvm/llvm_program.cpp index 2f1eb2e451f61..7744fbb83cc59 100644 --- a/taichi/runtime/program_impls/llvm/llvm_program.cpp +++ b/taichi/runtime/program_impls/llvm/llvm_program.cpp @@ -42,35 +42,21 @@ FunctionType LlvmProgramImpl::compile(Kernel *kernel, return codegen->compile_to_function(); } -std::unique_ptr -LlvmProgramImpl::clone_struct_compiler_initial_context( - bool has_multiple_snode_trees, - TaichiLLVMContext *tlctx) { - if (has_multiple_snode_trees) { - return tlctx->clone_struct_module(); - } - return tlctx->clone_runtime_module(); -} - std::unique_ptr LlvmProgramImpl::compile_snode_tree_types_impl( SNodeTree *tree) { auto *const root = tree->root(); - const bool has_multiple_snode_trees = (num_snode_trees_processed_ > 0); std::unique_ptr struct_compiler{nullptr}; if (arch_is_cpu(config->arch)) { - auto host_module = clone_struct_compiler_initial_context( - has_multiple_snode_trees, runtime_exec_->llvm_context_host_.get()); + auto host_module = runtime_exec_->llvm_context_host_.get()->new_module("struct"); struct_compiler = std::make_unique( host_arch(), this, std::move(host_module), tree->id()); } else if (config->arch == Arch::dx12) { - auto device_module = clone_struct_compiler_initial_context( - has_multiple_snode_trees, runtime_exec_->llvm_context_device_.get()); + auto device_module = runtime_exec_->llvm_context_device_.get()->new_module("struct"); struct_compiler = std::make_unique( Arch::dx12, this, std::move(device_module), tree->id()); } else { TI_ASSERT(config->arch == Arch::cuda); - auto device_module = clone_struct_compiler_initial_context( - has_multiple_snode_trees, runtime_exec_->llvm_context_device_.get()); + auto device_module = runtime_exec_->llvm_context_device_.get()->new_module("struct"); struct_compiler = std::make_unique( Arch::cuda, this, std::move(device_module), tree->id()); } diff --git a/taichi/runtime/program_impls/llvm/llvm_program.h b/taichi/runtime/program_impls/llvm/llvm_program.h index c84bea18766dc..2da386bc631fa 100644 --- a/taichi/runtime/program_impls/llvm/llvm_program.h +++ b/taichi/runtime/program_impls/llvm/llvm_program.h @@ -66,10 +66,6 @@ class LlvmProgramImpl : public ProgramImpl { } private: - std::unique_ptr clone_struct_compiler_initial_context( - bool has_multiple_snode_trees, - TaichiLLVMContext *tlctx); - std::unique_ptr compile_snode_tree_types_impl( SNodeTree *tree); diff --git a/taichi/runtime/wasm/aot_module_builder_impl.cpp b/taichi/runtime/wasm/aot_module_builder_impl.cpp index a4900798df4f2..c0bcacfb357e9 100644 --- a/taichi/runtime/wasm/aot_module_builder_impl.cpp +++ b/taichi/runtime/wasm/aot_module_builder_impl.cpp @@ -4,6 +4,7 @@ #include #include "taichi/util/file_sequence_writer.h" +#include "llvm/Linker/Linker.h" namespace taichi { namespace lang { @@ -36,11 +37,13 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir, void AotModuleBuilderImpl::add_per_backend(const std::string &identifier, Kernel *kernel) { auto module_info = - KernelCodeGenWASM(kernel, nullptr).compile_task(std::move(module_)); - module_ = std::move(module_info.module); - - for (auto &task : module_info.tasks) - name_list_.push_back(task.name); + KernelCodeGenWASM(kernel, nullptr).compile_kernel_to_module(); + if (module_) { + llvm::Linker::linkModules(*module_, std::move(module_info[0].module), + llvm::Linker::OverrideFromSrc); + } else { + module_ = std::move(module_info[0].module); + } } void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier, diff --git a/tests/cpp/codegen/refine_coordinates_test.cpp b/tests/cpp/codegen/refine_coordinates_test.cpp index 33399a604d5e5..45a1c44357f89 100644 --- a/tests/cpp/codegen/refine_coordinates_test.cpp +++ b/tests/cpp/codegen/refine_coordinates_test.cpp @@ -2,6 +2,7 @@ #include "gtest/gtest.h" #include +#include #include "llvm/IR/Function.h" #include "llvm/IR/Type.h" @@ -32,14 +33,21 @@ class InvokeRefineCoordinatesBuilder : public LLVMModuleBuilder { static FuncType build(const SNode *snode, TaichiLLVMContext *tlctx) { InvokeRefineCoordinatesBuilder mb{tlctx}; mb.run_jit(snode); - tlctx->add_module(std::move(mb.module)); - auto *fn = tlctx->lookup_function_pointer(kFuncName); + LLVMCompiledData data; + data.module = std::move(mb.module); + data.used_tree_ids = std::move(mb.used_snode_tree_ids); + data.tasks.emplace_back(kFuncName); + std::vector> data_list; + data_list.push_back(std::make_unique(std::move(data))); + auto linked_data = tlctx->link_compile_data(std::move(data_list)); + auto *jit = tlctx->create_jit_module(std::move(linked_data->module)); + auto *fn = jit->lookup_function(kFuncName); return reinterpret_cast(fn); } private: InvokeRefineCoordinatesBuilder(TaichiLLVMContext *tlctx) - : LLVMModuleBuilder(tlctx->clone_struct_module(), tlctx) { + : LLVMModuleBuilder(tlctx->new_module("kernel"), tlctx) { this->llvm_context = this->tlctx->get_this_thread_context(); this->builder = std::make_unique>(*llvm_context); } @@ -75,8 +83,15 @@ class InvokeRefineCoordinatesBuilder : public LLVMModuleBuilder { RuntimeObject parent_coords{kLLVMPhysicalCoordinatesName, this, builder.get()}; parent_coords.set("val", index0, parent_coords_first_component); + auto *refine_fn_struct = tlctx->get_struct_function( + snode->refine_coordinates_func_name(), snode->get_snode_tree_id()); auto *refine_fn = - get_runtime_function(snode->refine_coordinates_func_name()); + module + ->getOrInsertFunction(refine_fn_struct->getName(), + refine_fn_struct->getFunctionType(), + refine_fn_struct->getAttributes()) + .getCallee(); + used_snode_tree_ids.insert(snode->get_snode_tree_id()); RuntimeObject child_coords{kLLVMPhysicalCoordinatesName, this, builder.get()}; builder->CreateCall(refine_fn, @@ -86,6 +101,8 @@ class InvokeRefineCoordinatesBuilder : public LLVMModuleBuilder { llvm::verifyFunction(*func); } + + std::unordered_set used_snode_tree_ids; }; struct BitsRange { @@ -119,9 +136,9 @@ class RefineCoordinatesTest : public ::testing::Test { auto &leaf_snode = dense_snode_->insert_children(SNodeType::place); leaf_snode.dt = PrimitiveType::f32; - auto sc = std::make_unique( - arch_, &config_, tlctx_, tlctx_->clone_runtime_module(), - /*snode_tree_id=*/0); + auto sc = std::make_unique(arch_, &config_, tlctx_, + tlctx_->new_module("struct"), + /*snode_tree_id=*/0); sc->run(*root_snode_); } diff --git a/tests/cpp/llvm/llvm_offline_cache_test.cpp b/tests/cpp/llvm/llvm_offline_cache_test.cpp index 268aad42180ad..ef622db1e6414 100644 --- a/tests/cpp/llvm/llvm_offline_cache_test.cpp +++ b/tests/cpp/llvm/llvm_offline_cache_test.cpp @@ -128,9 +128,10 @@ TEST_P(LlvmOfflineCacheTest, ReadWrite) { ASSERT_NE(kcache.compiled_data_list[0].module, nullptr); kcache.compiled_data_list[0].module->dump(); - tlctx_->add_module(std::move(kcache.compiled_data_list[0].module)); + auto jit_module = tlctx_->create_jit_module( + std::move(kcache.compiled_data_list[0].module)); using FuncType = int (*)(int, int); - FuncType my_add = (FuncType)tlctx_->lookup_function_pointer(kTaskName); + FuncType my_add = (FuncType)jit_module->lookup_function(kTaskName); const auto res = my_add(40, 2); EXPECT_EQ(res, 42); };