Skip to content

Commit

Permalink
remove struct_module
Browse files Browse the repository at this point in the history
  • Loading branch information
lin-hitonami committed Sep 5, 2022
1 parent fded117 commit ace4cae
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 13 deletions.
4 changes: 2 additions & 2 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,8 @@ FunctionType CUDAModuleToFunctionConverter::convert(
}

auto jit = tlctx_->jit.get();
auto cuda_module = jit->add_module(
std::move(mod), executor_->get_config()->gpu_max_reg);
auto cuda_module =
jit->add_module(std::move(mod), executor_->get_config()->gpu_max_reg);

return [cuda_module, kernel_name, args, offloaded_tasks = tasks,
executor = this->executor_](RuntimeContext &context) {
Expand Down
1 change: 0 additions & 1 deletion taichi/codegen/wasm/codegen_wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ LLVMCompiledData KernelCodeGenWASM::compile_task(
return {name_list, std::move(gen->module), {}, {}};
}


std::vector<LLVMCompiledData> KernelCodeGenWASM::compile_kernel_to_module() {
auto *tlctx = get_llvm_program(prog)->get_llvm_context(kernel->arch);
if (!kernel->lowered()) {
Expand Down
6 changes: 3 additions & 3 deletions taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ void TaichiLLVMContext::link_module_with_cuda_libdevice(
}

void TaichiLLVMContext::add_struct_module(std::unique_ptr<Module> module,
int tree_id) {
int tree_id) {
TI_AUTO_PROF;
TI_ASSERT(std::this_thread::get_id() == main_thread_id_);
auto this_thread_data = get_this_thread_data();
Expand Down Expand Up @@ -663,7 +663,8 @@ llvm::DataLayout TaichiLLVMContext::get_data_layout() {
return jit->get_data_layout();
}

JITModule *TaichiLLVMContext::create_jit_module(std::unique_ptr<llvm::Module> module) {
JITModule *TaichiLLVMContext::create_jit_module(
std::unique_ptr<llvm::Module> module) {
return jit->add_module(std::move(module));
}

Expand Down Expand Up @@ -884,7 +885,6 @@ TaichiLLVMContext::ThreadLocalData::ThreadLocalData(

TaichiLLVMContext::ThreadLocalData::~ThreadLocalData() {
runtime_module.reset();
struct_module.reset();
struct_modules.clear();
thread_safe_llvm_context.reset();
}
Expand Down
9 changes: 5 additions & 4 deletions taichi/runtime/llvm/llvm_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ class TaichiLLVMContext {
std::unordered_map<int, std::unique_ptr<llvm::Module>> struct_modules;
ThreadLocalData(std::unique_ptr<llvm::orc::ThreadSafeContext> ctx);
~ThreadLocalData();
std::unique_ptr<llvm::Module> struct_module{nullptr}; // TODO: To be
// deleted
};
CompileConfig *config_;
//
This PR does roughly the following:
1. Adds `link_context_data` and the function `link_compile_data` to `TaichiLLVMContext` for linking.
2. Replaces `struct_module` with `struct_modules` in `ThreadLocalData`. `struct_modules` is a map that stores the struct module of every SNodeTree (the ID of the SNodeTree is the key of the map).

public:
std::unique_ptr<JITSession> jit{nullptr};
Expand Down Expand Up @@ -68,8 +70,7 @@ class TaichiLLVMContext {
*
* @param module Module containing the JIT compiled SNode structs.
*/
void add_struct_module(std::unique_ptr<llvm::Module> module,
int tree_id);
void add_struct_module(std::unique_ptr<llvm::Module> module, int tree_id);

/**
* Clones the LLVM module compiled from llvm/runtime.cpp
Expand Down
9 changes: 6 additions & 3 deletions taichi/runtime/program_impls/llvm/llvm_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,19 @@ std::unique_ptr<StructCompiler> LlvmProgramImpl::compile_snode_tree_types_impl(
auto *const root = tree->root();
std::unique_ptr<StructCompiler> struct_compiler{nullptr};
if (arch_is_cpu(config->arch)) {
auto host_module = runtime_exec_->llvm_context_host_.get()->new_module("struct");
auto host_module =
runtime_exec_->llvm_context_host_.get()->new_module("struct");
struct_compiler = std::make_unique<StructCompilerLLVM>(
host_arch(), this, std::move(host_module), tree->id());
} else if (config->arch == Arch::dx12) {
auto device_module = runtime_exec_->llvm_context_device_.get()->new_module("struct");
auto device_module =
runtime_exec_->llvm_context_device_.get()->new_module("struct");
struct_compiler = std::make_unique<StructCompilerLLVM>(
Arch::dx12, this, std::move(device_module), tree->id());
} else {
TI_ASSERT(config->arch == Arch::cuda);
auto device_module = runtime_exec_->llvm_context_device_.get()->new_module("struct");
auto device_module =
runtime_exec_->llvm_context_device_.get()->new_module("struct");
struct_compiler = std::make_unique<StructCompilerLLVM>(
Arch::cuda, this, std::move(device_module), tree->id());
}
Expand Down

0 comments on commit ace4cae

Please sign in to comment.