[llvm] [refactor] Link modules instead of cloning modules
lin-hitonami committed Sep 2, 2022
1 parent 315627a · commit fded117
Showing 21 changed files with 200 additions and 173 deletions.
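
The core of the refactor: instead of cloning the struct module into every kernel module, each offloaded task is compiled into its own module and the results are merged with the LLVM linker. A minimal sketch of that linking step, independent of Taichi's wrappers (the helper name and surrounding code are assumptions, not part of this commit):

// Sketch (assumed helper, not from this commit): merge per-task modules
// with llvm::Linker instead of cloning a shared module into each one.
#include "llvm/IR/Module.h"
#include "llvm/Linker/Linker.h"

#include <memory>
#include <vector>

std::unique_ptr<llvm::Module> link_modules(
    std::vector<std::unique_ptr<llvm::Module>> modules) {
  // Precondition: `modules` is non-empty and all modules share one
  // llvm::LLVMContext; the first module becomes the link destination.
  auto dest = std::move(modules.front());
  llvm::Linker linker(*dest);
  for (size_t i = 1; i < modules.size(); i++) {
    // linkInModule() consumes the source module and returns true on error.
    if (linker.linkInModule(std::move(modules[i])))
      return nullptr;
  }
  return dest;
}
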
2 changes: 1 addition & 1 deletion taichi/aot/module_builder.cpp
@@ -6,7 +6,7 @@ namespace lang {

void AotModuleBuilder::add(const std::string &identifier, Kernel *kernel) {
if (!kernel->lowered() && Kernel::supports_lowering(kernel->arch)) {
kernel->lower();
kernel->lower(/*to_executable=*/!arch_uses_llvm(kernel->arch));
}
add_per_backend(identifier, kernel);
}
17 changes: 12 additions & 5 deletions taichi/codegen/codegen.cpp
@@ -87,6 +87,8 @@ void KernelCodeGen::cache_module(const std::string &kernel_key,
}

std::vector<LLVMCompiledData> KernelCodeGen::compile_kernel_to_module() {
auto *llvm_prog = get_llvm_program(prog);
auto *tlctx = llvm_prog->get_llvm_context(kernel->arch);
auto &config = prog->config;
std::string kernel_key = get_hashed_offline_cache_key(&config, kernel);
kernel->set_kernel_key_for_cache(kernel_key);
@@ -98,6 +100,7 @@ std::vector<LLVMCompiledData> KernelCodeGen::compile_kernel_to_module() {
TI_DEBUG("Create kernel '{}' from cache (key='{}')", kernel->get_name(),
kernel_key);
cache_module(kernel_key, res);
TI_ASSERT(res.size() == 1);
return res;
}
}
@@ -110,17 +113,17 @@ std::vector<LLVMCompiledData> KernelCodeGen::compile_kernel_to_module() {
TI_ASSERT(block);

auto &offloads = block->statements;
std::vector<LLVMCompiledData> data(offloads.size());
std::vector<std::unique_ptr<LLVMCompiledData>> data(offloads.size());
using TaskFunc = int32 (*)(void *);
std::vector<TaskFunc> task_funcs(offloads.size());
for (int i = 0; i < offloads.size(); i++) {
auto compile_func = [&, i] {
tlctx->fetch_this_thread_struct_module();
auto offload =
irpass::analysis::clone(offloads[i].get(), offloads[i]->get_kernel());
irpass::re_id(offload.get());
auto new_data = this->compile_task(nullptr, offload->as<OffloadedStmt>());
data[i].tasks = std::move(new_data.tasks);
data[i].module = std::move(new_data.module);
data[i] = std::make_unique<LLVMCompiledData>(std::move(new_data));
};
if (kernel->is_evaluator) {
compile_func();
@@ -131,11 +134,15 @@ std::vector<LLVMCompiledData> KernelCodeGen::compile_kernel_to_module() {
if (!kernel->is_evaluator) {
worker.flush();
}
auto linked = tlctx->link_compile_data(std::move(data));
std::vector<LLVMCompiledData> linked_data;
linked_data.push_back(std::move(*linked));

if (!kernel->is_evaluator) {
TI_DEBUG("Cache kernel '{}' (key='{}')", kernel->get_name(), kernel_key);
cache_module(kernel_key, data);
cache_module(kernel_key, linked_data);
}
return data;
return linked_data;
}

ModuleToFunctionConverter::ModuleToFunctionConverter(
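
Pulled out of the diff noise, the new compile_kernel_to_module flow reduces to: compile every offload into its own LLVMCompiledData, then let the context link them into a single one. A condensed reading of the hunks above (a sketch; the worker-thread dispatch and caching from the real code are omitted):

// Condensed sketch of the flow above, not the full implementation.
std::vector<std::unique_ptr<LLVMCompiledData>> data(offloads.size());
for (int i = 0; i < offloads.size(); i++) {
  auto offload =
      irpass::analysis::clone(offloads[i].get(), offloads[i]->get_kernel());
  data[i] = std::make_unique<LLVMCompiledData>(
      compile_task(nullptr, offload->as<OffloadedStmt>()));
}
// One linked module (and its merged task list) comes out the other end.
std::vector<LLVMCompiledData> linked_data;
linked_data.push_back(std::move(*tlctx->link_compile_data(std::move(data))));
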
2 changes: 1 addition & 1 deletion taichi/codegen/codegen.h
@@ -33,7 +33,7 @@ class KernelCodeGen {
}

#ifdef TI_WITH_LLVM
std::vector<LLVMCompiledData> compile_kernel_to_module();
virtual std::vector<LLVMCompiledData> compile_kernel_to_module();

virtual LLVMCompiledData compile_task(
std::unique_ptr<llvm::Module> &&module = nullptr,
8 changes: 3 additions & 5 deletions taichi/codegen/cpu/codegen_cpu.cpp
@@ -235,16 +235,14 @@ FunctionType CPUModuleToFunctionConverter::convert(
const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
std::vector<LLVMCompiledData> &&data) const {
for (auto &datum : data) {
tlctx_->add_module(std::move(datum.module));
}

TI_AUTO_PROF;
auto jit_module = tlctx_->create_jit_module(std::move(data.back().module));
using TaskFunc = int32 (*)(void *);
std::vector<TaskFunc> task_funcs;
task_funcs.reserve(data.size());
for (auto &datum : data) {
for (auto &task : datum.tasks) {
auto *func_ptr = tlctx_->lookup_function_pointer(task.name);
auto *func_ptr = jit_module->lookup_function(task.name);
TI_ASSERT_INFO(func_ptr, "Offloaded datum function {} not found",
task.name);
task_funcs.push_back((TaskFunc)(func_ptr));
39 changes: 16 additions & 23 deletions taichi/codegen/cuda/codegen_cuda.cpp
@@ -715,25 +715,20 @@ FunctionType CUDAModuleToFunctionConverter::convert(
const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
std::vector<LLVMCompiledData> &&data) const {
auto &mod = data[0].module;
auto &tasks = data[0].tasks;
#ifdef TI_WITH_CUDA
std::vector<JITModule *> cuda_modules;
std::vector<std::vector<OffloadedTask>> offloaded_tasks;
cuda_modules.reserve(data.size());
for (auto &datum : data) {
auto &mod = datum.module;
auto &tasks = datum.tasks;
for (const auto &task : tasks) {
llvm::Function *func = mod->getFunction(task.name);
TI_ASSERT(func);
tlctx_->mark_function_as_cuda_kernel(func, task.block_dim);
}
auto jit = tlctx_->jit.get();
cuda_modules.push_back(
jit->add_module(std::move(mod), executor_->get_config()->gpu_max_reg));
offloaded_tasks.push_back(std::move(tasks));
for (const auto &task : tasks) {
llvm::Function *func = mod->getFunction(task.name);
TI_ASSERT(func);
tlctx_->mark_function_as_cuda_kernel(func, task.block_dim);
}

return [cuda_modules, kernel_name, args, offloaded_tasks,
auto jit = tlctx_->jit.get();
auto cuda_module = jit->add_module(
std::move(mod), executor_->get_config()->gpu_max_reg);

return [cuda_module, kernel_name, args, offloaded_tasks = tasks,
executor = this->executor_](RuntimeContext &context) {
CUDAContext::get_instance().make_current();
std::vector<void *> arg_buffers(args.size(), nullptr);
@@ -797,13 +792,11 @@ FunctionType CUDAModuleToFunctionConverter::convert(
CUDADriver::get_instance().stream_synchronize(nullptr);
}

for (int i = 0; i < offloaded_tasks.size(); i++) {
for (auto &task : offloaded_tasks[i]) {
TI_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim,
task.block_dim);
cuda_modules[i]->launch(task.name, task.grid_dim, task.block_dim, 0,
{&context});
}
for (auto task : offloaded_tasks) {
TI_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim,
task.block_dim);
cuda_module->launch(task.name, task.grid_dim, task.block_dim, 0,
{&context});
}

// copy data back to host
25 changes: 14 additions & 11 deletions taichi/codegen/llvm/codegen_llvm.cpp
@@ -318,11 +318,13 @@ void TaskCodeGenLLVM::emit_struct_meta_base(const std::string &name,
// snodes, even if they have the same type.
if (snode->parent)
common.set("from_parent_element",
get_runtime_function(snode->get_ch_from_parent_func_name()));
get_struct_function(snode->get_ch_from_parent_func_name(),
snode->get_snode_tree_id()));

if (snode->type != SNodeType::place)
common.set("refine_coordinates",
get_runtime_function(snode->refine_coordinates_func_name()));
get_struct_function(snode->refine_coordinates_func_name(),
snode->get_snode_tree_id()));
}

TaskCodeGenLLVM::TaskCodeGenLLVM(Kernel *kernel,
@@ -332,7 +334,7 @@ TaskCodeGenLLVM::TaskCodeGenLLVM(Kernel *kernel,
: LLVMModuleBuilder(
module == nullptr ? get_llvm_program(kernel->program)
->get_llvm_context(kernel->arch)
->clone_struct_module()
->new_module("kernel")
: std::move(module),
get_llvm_program(kernel->program)->get_llvm_context(kernel->arch)),
kernel(kernel),
@@ -1706,10 +1708,11 @@ void TaskCodeGenLLVM::visit(GetChStmt *stmt) {
auto offset = tlctx->get_constant(bit_offset);
llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_ptr], offset);
} else {
auto ch = create_call(stmt->output_snode->get_ch_from_parent_func_name(),
{builder->CreateBitCast(
llvm_val[stmt->input_ptr],
llvm::PointerType::getInt8PtrTy(*llvm_context))});
auto ch = call_struct_func(
stmt->output_snode->get_snode_tree_id(),
stmt->output_snode->get_ch_from_parent_func_name(),
builder->CreateBitCast(llvm_val[stmt->input_ptr],
llvm::PointerType::getInt8PtrTy(*llvm_context)));
llvm_val[stmt] = builder->CreateBitCast(
ch, llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), stmt->output_snode),
@@ -1989,7 +1992,8 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt,
create_entry_block_alloca(physical_coordinate_ty);

auto refine =
get_runtime_function(leaf_block->refine_coordinates_func_name());
get_struct_function(leaf_block->refine_coordinates_func_name(),
leaf_block->get_snode_tree_id());
// A block corner is the global coordinate/index of the lower-left corner
// cell within that block, and is the same for all the cells within that
// block.
@@ -2068,8 +2072,8 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt,
// needed to make final coordinates non-consecutive, since each thread will
// process multiple coordinates via vectorization
if (stmt->is_bit_vectorized) {
refine =
get_runtime_function(stmt->snode->refine_coordinates_func_name());
refine = get_struct_function(stmt->snode->refine_coordinates_func_name(),
stmt->snode->get_snode_tree_id());
create_call(refine,
{new_coordinates, new_coordinates, tlctx->get_constant(0)});
}
@@ -2156,7 +2160,6 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt,
auto struct_for_func = get_runtime_function("parallel_struct_for");

if (arch_is_gpu(current_arch())) {
tlctx->add_struct_for_func(module.get(), stmt->tls_size);
struct_for_func = llvm::cast<llvm::Function>(
module
->getOrInsertFunction(
30 changes: 7 additions & 23 deletions taichi/codegen/llvm/llvm_codegen_utils.h
@@ -87,35 +87,19 @@ class LLVMModuleBuilder {
}

llvm::Type *get_runtime_type(const std::string &name) {
#ifdef TI_LLVM_15
auto ty = llvm::StructType::getTypeByName(module->getContext(),
("struct." + name));
#else
auto ty = module->getTypeByName("struct." + name);
#endif
if (!ty) {
TI_ERROR("LLVMRuntime type {} not found.", name);
}
return ty;
return tlctx->get_runtime_type(name);
}

llvm::Function *get_runtime_function(const std::string &name) {
auto f = module->getFunction(name);
auto f = tlctx->get_runtime_function(name);
if (!f) {
TI_ERROR("LLVMRuntime function {} not found.", name);
}
#ifdef TI_LLVM_15
f->removeFnAttr(llvm::Attribute::OptimizeNone);
f->removeFnAttr(llvm::Attribute::NoInline);
f->addFnAttr(llvm::Attribute::AlwaysInline);
#else
f->removeAttribute(llvm::AttributeList::FunctionIndex,
llvm::Attribute::OptimizeNone);
f->removeAttribute(llvm::AttributeList::FunctionIndex,
llvm::Attribute::NoInline);
f->addAttribute(llvm::AttributeList::FunctionIndex,
llvm::Attribute::AlwaysInline);
#endif
f = llvm::cast<llvm::Function>(
module
->getOrInsertFunction(name, f->getFunctionType(),
f->getAttributes())
.getCallee());
return f;
}

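
The rewritten get_runtime_function above no longer strips attributes from a function cloned into the kernel module; it fetches the runtime prototype from tlctx and re-declares it locally via getOrInsertFunction, so the call site can be emitted now and the body resolved when the modules are linked. The same idiom in isolation (declare_in is a hypothetical helper, not part of the commit):

// Declare `proto` in module `m` without its body; the definition is
// supplied later at link time. (Hypothetical helper for illustration.)
#include "llvm/IR/Module.h"
#include "llvm/Support/Casting.h"

llvm::Function *declare_in(llvm::Module &m, llvm::Function *proto) {
  return llvm::cast<llvm::Function>(
      m.getOrInsertFunction(proto->getName(), proto->getFunctionType(),
                            proto->getAttributes())
          .getCallee());
}
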
1 change: 1 addition & 0 deletions taichi/codegen/llvm/llvm_compiled_data.h
@@ -1,6 +1,7 @@
#pragma once

#include <memory>
#include <unordered_set>

#include "llvm/IR/Module.h"

3 changes: 1 addition & 2 deletions taichi/codegen/llvm/struct_llvm.cpp
@@ -292,7 +292,7 @@ void StructCompilerLLVM::run(SNode &root) {
auto node_type = get_llvm_node_type(module.get(), &root);
root_size = tlctx_->get_type_size(node_type);

tlctx_->set_struct_module(module);
tlctx_->add_struct_module(std::move(module), root.get_snode_tree_id());
}

llvm::Type *StructCompilerLLVM::get_stub(llvm::Module *module,
@@ -336,7 +336,6 @@ llvm::Type *StructCompilerLLVM::get_llvm_element_type(llvm::Module *module,

llvm::Function *StructCompilerLLVM::create_function(llvm::FunctionType *ft,
std::string func_name) {
tlctx_->add_function_to_snode_tree(snode_tree_id_, func_name);
return llvm::Function::Create(ft, llvm::Function::ExternalLinkage, func_name,
*module);
}
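
Where set_struct_module held a single struct module for the whole context, add_struct_module stores one per SNode tree, which is what lets the codegen_llvm.cpp hunks above resolve struct functions with get_struct_function(name, tree_id). A sketch of the bookkeeping this implies (class and member names are assumptions; only the two method names appear in the diff):

// Sketch of per-tree struct module storage (assumed shape, not the
// actual TaichiLLVMContext implementation).
#include <memory>
#include <string>
#include <unordered_map>
#include "llvm/IR/Module.h"

class StructModuleRegistry {
  std::unordered_map<int, std::unique_ptr<llvm::Module>> struct_modules_;

 public:
  void add_struct_module(std::unique_ptr<llvm::Module> m, int tree_id) {
    struct_modules_[tree_id] = std::move(m);
  }

  llvm::Function *get_struct_function(const std::string &name, int tree_id) {
    auto it = struct_modules_.find(tree_id);
    return it == struct_modules_.end() ? nullptr
                                       : it->second->getFunction(name);
  }
};
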
26 changes: 22 additions & 4 deletions taichi/codegen/wasm/codegen_wasm.cpp
@@ -9,6 +9,7 @@
#include "taichi/ir/statements.h"
#include "taichi/util/statistics.h"
#include "taichi/util/file_sequence_writer.h"
#include "taichi/runtime/program_impls/llvm/llvm_program.h"

namespace taichi {
namespace lang {
@@ -243,10 +244,10 @@ class TaskCodeGenWASM : public TaskCodeGenLLVM {

FunctionType KernelCodeGenWASM::compile_to_function() {
TI_AUTO_PROF
TaskCodeGenWASM gen(kernel, ir);
auto res = gen.run_compilation();
gen.tlctx->add_module(std::move(res.module));
auto kernel_symbol = gen.tlctx->lookup_function_pointer(res.tasks[0].name);
auto linked = std::move(compile_kernel_to_module()[0]);
auto *tlctx = get_llvm_program(prog)->get_llvm_context(kernel->arch);
tlctx->create_jit_module(std::move(linked.module));
auto kernel_symbol = tlctx->lookup_function_pointer(linked.tasks[0].name);
return [=](RuntimeContext &context) {
TI_TRACE("Launching Taichi Kernel Function");
auto func = (int32(*)(void *))kernel_symbol;
@@ -257,6 +258,7 @@ FunctionType KernelCodeGenWASM::compile_to_function() {
LLVMCompiledData KernelCodeGenWASM::compile_task(
std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
kernel->offload_to_executable(ir);
bool init_flag = module == nullptr;
std::vector<OffloadedTask> name_list;
auto gen = std::make_unique<TaskCodeGenWASM>(kernel, ir, std::move(module));
@@ -278,5 +280,21 @@

return {name_list, std::move(gen->module), {}, {}};
}


std::vector<LLVMCompiledData> KernelCodeGenWASM::compile_kernel_to_module() {
auto *tlctx = get_llvm_program(prog)->get_llvm_context(kernel->arch);
if (!kernel->lowered()) {
kernel->lower(/*to_executable=*/false);
}
auto res = compile_task();
std::vector<std::unique_ptr<LLVMCompiledData>> data;
data.push_back(std::make_unique<LLVMCompiledData>(std::move(res)));
auto linked = tlctx->link_compile_data(std::move(data));
std::vector<LLVMCompiledData> ret;
ret.push_back(std::move(*linked));
return ret;
}

} // namespace lang
} // namespace taichi
2 changes: 2 additions & 0 deletions taichi/codegen/wasm/codegen_wasm.h
@@ -23,6 +23,8 @@ class KernelCodeGenWASM : public KernelCodeGen {
LLVMCompiledData compile_task(
std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr) override; // AOT Module Gen

std::vector<LLVMCompiledData> compile_kernel_to_module() override;
#endif
};

4 changes: 2 additions & 2 deletions taichi/runtime/cpu/aot_module_builder_impl.cpp
@@ -10,8 +10,8 @@ namespace lang {
namespace cpu {

LLVMCompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
auto cgen = KernelCodeGenCPU::make_codegen_llvm(kernel, /*ir=*/nullptr);
return cgen->run_compilation();
auto cgen = KernelCodeGenCPU(kernel);
return std::move(cgen.compile_kernel_to_module()[0]);
}

} // namespace cpu
4 changes: 2 additions & 2 deletions taichi/runtime/cuda/aot_module_builder_impl.cpp
@@ -10,8 +10,8 @@ namespace lang {
namespace cuda {

LLVMCompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
auto cgen = KernelCodeGenCUDA::make_codegen_llvm(kernel, /*ir=*/nullptr);
return cgen->run_compilation();
auto cgen = KernelCodeGenCUDA(kernel);
return std::move(cgen.compile_kernel_to_module()[0]);
}

} // namespace cuda
(8 of the 21 changed files are not shown above)
