[llvm] [refactor] Link modules instead of cloning modules #5962

Merged: 9 commits, Sep 6, 2022
2 changes: 1 addition & 1 deletion taichi/aot/module_builder.cpp
```diff
@@ -6,7 +6,7 @@ namespace lang {
 
 void AotModuleBuilder::add(const std::string &identifier, Kernel *kernel) {
   if (!kernel->lowered() && Kernel::supports_lowering(kernel->arch)) {
-    kernel->lower();
+    kernel->lower(/*to_executable=*/!arch_uses_llvm(kernel->arch));
   }
   add_per_backend(identifier, kernel);
 }
```
17 changes: 12 additions & 5 deletions taichi/codegen/codegen.cpp
```diff
@@ -87,6 +87,8 @@ void KernelCodeGen::cache_module(const std::string &kernel_key,
 }
 
 std::vector<LLVMCompiledData> KernelCodeGen::compile_kernel_to_module() {
+  auto *llvm_prog = get_llvm_program(prog);
+  auto *tlctx = llvm_prog->get_llvm_context(kernel->arch);
   auto &config = prog->config;
   std::string kernel_key = get_hashed_offline_cache_key(&config, kernel);
   kernel->set_kernel_key_for_cache(kernel_key);
@@ -98,6 +100,7 @@ std::vector<LLVMCompiledData> KernelCodeGen::compile_kernel_to_module() {
       TI_DEBUG("Create kernel '{}' from cache (key='{}')", kernel->get_name(),
                kernel_key);
       cache_module(kernel_key, res);
+      TI_ASSERT(res.size() == 1);
       return res;
     }
   }
@@ -110,17 +113,17 @@ std::vector<LLVMCompiledData> KernelCodeGen::compile_kernel_to_module() {
   TI_ASSERT(block);
 
   auto &offloads = block->statements;
-  std::vector<LLVMCompiledData> data(offloads.size());
+  std::vector<std::unique_ptr<LLVMCompiledData>> data(offloads.size());
   using TaskFunc = int32 (*)(void *);
   std::vector<TaskFunc> task_funcs(offloads.size());
   for (int i = 0; i < offloads.size(); i++) {
     auto compile_func = [&, i] {
+      tlctx->fetch_this_thread_struct_module();
       auto offload =
           irpass::analysis::clone(offloads[i].get(), offloads[i]->get_kernel());
       irpass::re_id(offload.get());
       auto new_data = this->compile_task(nullptr, offload->as<OffloadedStmt>());
-      data[i].tasks = std::move(new_data.tasks);
-      data[i].module = std::move(new_data.module);
+      data[i] = std::make_unique<LLVMCompiledData>(std::move(new_data));
     };
     if (kernel->is_evaluator) {
       compile_func();
@@ -131,11 +134,15 @@ std::vector<LLVMCompiledData> KernelCodeGen::compile_kernel_to_module() {
   if (!kernel->is_evaluator) {
     worker.flush();
   }
+  auto linked = tlctx->link_compiled_tasks(std::move(data));
+  std::vector<LLVMCompiledData> linked_data;
+  linked_data.push_back(std::move(*linked));
 
   if (!kernel->is_evaluator) {
     TI_DEBUG("Cache kernel '{}' (key='{}')", kernel->get_name(), kernel_key);
-    cache_module(kernel_key, data);
+    cache_module(kernel_key, linked_data);
   }
-  return data;
+  return linked_data;
 }
 
 ModuleToFunctionConverter::ModuleToFunctionConverter(
```
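The heart of the PR is visible here: each offloaded task is compiled into its own small `llvm::Module`, and `link_compiled_tasks` then merges the per-task modules into a single `LLVMCompiledData`, instead of every task starting from a clone of the whole struct module. The internals of `link_compiled_tasks` are not part of this diff; a minimal sketch of module linking with the LLVM C++ API, under the assumption that all modules share one `LLVMContext`, looks like this:

```cpp
#include <memory>
#include <vector>

#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Linker/Linker.h"

// Hypothetical helper: merge per-task modules into a fresh destination module.
// llvm::Linker requires every module to live in the same LLVMContext.
std::unique_ptr<llvm::Module> link_task_modules(
    std::vector<std::unique_ptr<llvm::Module>> task_modules,
    llvm::LLVMContext &ctx) {
  auto dest = std::make_unique<llvm::Module>("kernel", ctx);
  llvm::Linker linker(*dest);
  for (auto &m : task_modules) {
    // linkInModule consumes the source module and returns true on error.
    if (linker.linkInModule(std::move(m)))
      return nullptr;
  }
  return dest;
}
```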
2 changes: 1 addition & 1 deletion taichi/codegen/codegen.h
```diff
@@ -33,7 +33,7 @@ class KernelCodeGen {
   }
 
 #ifdef TI_WITH_LLVM
-  std::vector<LLVMCompiledData> compile_kernel_to_module();
+  virtual std::vector<LLVMCompiledData> compile_kernel_to_module();
 
   virtual LLVMCompiledData compile_task(
       std::unique_ptr<llvm::Module> &&module = nullptr,
```
8 changes: 3 additions & 5 deletions taichi/codegen/cpu/codegen_cpu.cpp
```diff
@@ -235,16 +235,14 @@ FunctionType CPUModuleToFunctionConverter::convert(
     const std::string &kernel_name,
     const std::vector<LlvmLaunchArgInfo> &args,
    std::vector<LLVMCompiledData> &&data) const {
-  for (auto &datum : data) {
-    tlctx_->add_module(std::move(datum.module));
-  }
-
   TI_AUTO_PROF;
+  auto jit_module = tlctx_->create_jit_module(std::move(data.back().module));
   using TaskFunc = int32 (*)(void *);
   std::vector<TaskFunc> task_funcs;
   task_funcs.reserve(data.size());
   for (auto &datum : data) {
     for (auto &task : datum.tasks) {
-      auto *func_ptr = tlctx_->lookup_function_pointer(task.name);
+      auto *func_ptr = jit_module->lookup_function(task.name);
       TI_ASSERT_INFO(func_ptr, "Offloaded datum function {} not found",
                      task.name);
       task_funcs.push_back((TaskFunc)(func_ptr));
```
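With the tasks pre-linked, the CPU converter now JITs exactly one module (`data.back().module`) and resolves every task symbol from the resulting `jit_module`, rather than adding one module per task to the context. A sketch of the same single-module pattern with LLVM's ORC `LLJIT` (assuming the LLVM 15 API; Taichi's own `JITModule` wrapper is not shown in this diff):

```cpp
#include <cstdint>
#include <memory>

#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Error.h"

using TaskFunc = int32_t (*)(void *);

// Add the single linked module once; the JIT must outlive any resolved
// function pointers.
llvm::Error add_linked_module(llvm::orc::LLJIT &jit,
                              std::unique_ptr<llvm::Module> linked,
                              std::unique_ptr<llvm::LLVMContext> ctx) {
  return jit.addIRModule(
      llvm::orc::ThreadSafeModule(std::move(linked), std::move(ctx)));
}

// Every task entry point resolves from the same module.
llvm::Expected<TaskFunc> lookup_task(llvm::orc::LLJIT &jit,
                                     llvm::StringRef name) {
  auto addr = jit.lookup(name);  // Expected<ExecutorAddr> as of LLVM 15
  if (!addr)
    return addr.takeError();
  return addr->toPtr<TaskFunc>();
}
```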
39 changes: 16 additions & 23 deletions taichi/codegen/cuda/codegen_cuda.cpp
```diff
@@ -715,25 +715,20 @@ FunctionType CUDAModuleToFunctionConverter::convert(
     const std::string &kernel_name,
     const std::vector<LlvmLaunchArgInfo> &args,
     std::vector<LLVMCompiledData> &&data) const {
+  auto &mod = data[0].module;
+  auto &tasks = data[0].tasks;
 #ifdef TI_WITH_CUDA
-  std::vector<JITModule *> cuda_modules;
-  std::vector<std::vector<OffloadedTask>> offloaded_tasks;
-  cuda_modules.reserve(data.size());
-  for (auto &datum : data) {
-    auto &mod = datum.module;
-    auto &tasks = datum.tasks;
-    for (const auto &task : tasks) {
-      llvm::Function *func = mod->getFunction(task.name);
-      TI_ASSERT(func);
-      tlctx_->mark_function_as_cuda_kernel(func, task.block_dim);
-    }
-    auto jit = tlctx_->jit.get();
-    cuda_modules.push_back(
-        jit->add_module(std::move(mod), executor_->get_config()->gpu_max_reg));
-    offloaded_tasks.push_back(std::move(tasks));
+  for (const auto &task : tasks) {
+    llvm::Function *func = mod->getFunction(task.name);
+    TI_ASSERT(func);
+    tlctx_->mark_function_as_cuda_kernel(func, task.block_dim);
   }
 
-  return [cuda_modules, kernel_name, args, offloaded_tasks,
+  auto jit = tlctx_->jit.get();
+  auto cuda_module =
+      jit->add_module(std::move(mod), executor_->get_config()->gpu_max_reg);
+
+  return [cuda_module, kernel_name, args, offloaded_tasks = tasks,
           executor = this->executor_](RuntimeContext &context) {
     CUDAContext::get_instance().make_current();
     std::vector<void *> arg_buffers(args.size(), nullptr);
@@ -797,13 +792,11 @@ FunctionType CUDAModuleToFunctionConverter::convert(
       CUDADriver::get_instance().stream_synchronize(nullptr);
     }
 
-    for (int i = 0; i < offloaded_tasks.size(); i++) {
-      for (auto &task : offloaded_tasks[i]) {
-        TI_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim,
-                 task.block_dim);
-        cuda_modules[i]->launch(task.name, task.grid_dim, task.block_dim, 0,
-                                {&context});
-      }
+    for (auto task : offloaded_tasks) {
+      TI_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim,
+               task.block_dim);
+      cuda_module->launch(task.name, task.grid_dim, task.block_dim, 0,
+                          {&context});
     }
 
     // copy data back to host
```
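The CUDA side collapses the same way: one JIT module per kernel and a flat launch loop over its tasks, replacing the parallel `cuda_modules[i]` bookkeeping. A rough CUDA driver API equivalent of the new launch loop (the `Task` struct is a simplified stand-in for `OffloadedTask`, and `CUresult` error checks are omitted for brevity):

```cpp
#include <cuda.h>

#include <string>
#include <vector>

// Stand-in for OffloadedTask: just the fields the launch loop needs.
struct Task {
  std::string name;
  unsigned grid_dim;
  unsigned block_dim;
};

// Launch every task from ONE loaded module, mirroring the refactored loop.
void launch_all(const void *ptx_image, const std::vector<Task> &tasks,
                void *runtime_context) {
  CUmodule mod;
  cuModuleLoadData(&mod, ptx_image);  // a single module holds all task kernels
  for (const auto &t : tasks) {
    CUfunction fn;
    cuModuleGetFunction(&fn, mod, t.name.c_str());
    void *params[] = {&runtime_context};
    cuLaunchKernel(fn, t.grid_dim, 1, 1, t.block_dim, 1, 1,
                   /*sharedMemBytes=*/0, /*hStream=*/nullptr, params, nullptr);
  }
}
```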
25 changes: 14 additions & 11 deletions taichi/codegen/llvm/codegen_llvm.cpp
```diff
@@ -318,11 +318,13 @@ void TaskCodeGenLLVM::emit_struct_meta_base(const std::string &name,
   // snodes, even if they have the same type.
   if (snode->parent)
     common.set("from_parent_element",
-               get_runtime_function(snode->get_ch_from_parent_func_name()));
+               get_struct_function(snode->get_ch_from_parent_func_name(),
+                                   snode->get_snode_tree_id()));
 
   if (snode->type != SNodeType::place)
     common.set("refine_coordinates",
-               get_runtime_function(snode->refine_coordinates_func_name()));
+               get_struct_function(snode->refine_coordinates_func_name(),
+                                   snode->get_snode_tree_id()));
 }
 
 TaskCodeGenLLVM::TaskCodeGenLLVM(Kernel *kernel,
@@ -332,7 +334,7 @@ TaskCodeGenLLVM::TaskCodeGenLLVM(Kernel *kernel,
     : LLVMModuleBuilder(
           module == nullptr ? get_llvm_program(kernel->program)
                                   ->get_llvm_context(kernel->arch)
-                                  ->clone_struct_module()
+                                  ->new_module("kernel")
                             : std::move(module),
           get_llvm_program(kernel->program)->get_llvm_context(kernel->arch)),
       kernel(kernel),
@@ -1706,10 +1708,11 @@ void TaskCodeGenLLVM::visit(GetChStmt *stmt) {
     auto offset = tlctx->get_constant(bit_offset);
     llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_ptr], offset);
   } else {
-    auto ch = create_call(stmt->output_snode->get_ch_from_parent_func_name(),
-                          {builder->CreateBitCast(
-                              llvm_val[stmt->input_ptr],
-                              llvm::PointerType::getInt8PtrTy(*llvm_context))});
+    auto ch = call_struct_func(
+        stmt->output_snode->get_snode_tree_id(),
+        stmt->output_snode->get_ch_from_parent_func_name(),
+        builder->CreateBitCast(llvm_val[stmt->input_ptr],
+                               llvm::PointerType::getInt8PtrTy(*llvm_context)));
     llvm_val[stmt] = builder->CreateBitCast(
         ch, llvm::PointerType::get(StructCompilerLLVM::get_llvm_node_type(
                 module.get(), stmt->output_snode),
@@ -1989,7 +1992,8 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt,
         create_entry_block_alloca(physical_coordinate_ty);
 
     auto refine =
-        get_runtime_function(leaf_block->refine_coordinates_func_name());
+        get_struct_function(leaf_block->refine_coordinates_func_name(),
+                            leaf_block->get_snode_tree_id());
     // A block corner is the global coordinate/index of the lower-left corner
     // cell within that block, and is the same for all the cells within that
     // block.
@@ -2068,8 +2072,8 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt,
     // needed to make final coordinates non-consecutive, since each thread will
     // process multiple coordinates via vectorization
     if (stmt->is_bit_vectorized) {
-      refine =
-          get_runtime_function(stmt->snode->refine_coordinates_func_name());
+      refine = get_struct_function(stmt->snode->refine_coordinates_func_name(),
+                                   stmt->snode->get_snode_tree_id());
       create_call(refine,
                   {new_coordinates, new_coordinates, tlctx->get_constant(0)});
     }
@@ -2156,7 +2160,6 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt,
     auto struct_for_func = get_runtime_function("parallel_struct_for");
 
     if (arch_is_gpu(current_arch())) {
-      tlctx->add_struct_for_func(module.get(), stmt->tls_size);
       struct_for_func = llvm::cast<llvm::Function>(
           module
               ->getOrInsertFunction(
```
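`get_struct_function` and `call_struct_func` replace plain `get_runtime_function` lookups because SNode helpers such as `refine_coordinates` now live in per-tree struct modules instead of being cloned into every kernel module. The registry itself is not shown in this diff; a minimal sketch of keying struct modules by `snode_tree_id` (the class and member names here are illustrative assumptions, not Taichi's actual layout):

```cpp
#include <memory>
#include <string>
#include <unordered_map>

#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"

// Illustrative registry: one struct module per SNode tree, so kernels can
// pull helper declarations without cloning the whole struct module.
class StructModuleRegistry {
 public:
  void add_struct_module(std::unique_ptr<llvm::Module> m, int tree_id) {
    modules_[tree_id] = std::move(m);
  }

  // Counterpart of get_struct_function(name, tree_id) in the diff.
  llvm::Function *get_struct_function(const std::string &name, int tree_id) {
    auto it = modules_.find(tree_id);
    return it == modules_.end() ? nullptr : it->second->getFunction(name);
  }

 private:
  std::unordered_map<int, std::unique_ptr<llvm::Module>> modules_;
};
```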
30 changes: 7 additions & 23 deletions taichi/codegen/llvm/llvm_codegen_utils.h
```diff
@@ -87,35 +87,19 @@ class LLVMModuleBuilder {
   }
 
   llvm::Type *get_runtime_type(const std::string &name) {
-#ifdef TI_LLVM_15
-    auto ty = llvm::StructType::getTypeByName(module->getContext(),
-                                              ("struct." + name));
-#else
-    auto ty = module->getTypeByName("struct." + name);
-#endif
-    if (!ty) {
-      TI_ERROR("LLVMRuntime type {} not found.", name);
-    }
-    return ty;
+    return tlctx->get_runtime_type(name);
   }
 
   llvm::Function *get_runtime_function(const std::string &name) {
-    auto f = module->getFunction(name);
+    auto f = tlctx->get_runtime_function(name);
     if (!f) {
       TI_ERROR("LLVMRuntime function {} not found.", name);
     }
-#ifdef TI_LLVM_15
-    f->removeFnAttr(llvm::Attribute::OptimizeNone);
-    f->removeFnAttr(llvm::Attribute::NoInline);
-    f->addFnAttr(llvm::Attribute::AlwaysInline);
-#else
-    f->removeAttribute(llvm::AttributeList::FunctionIndex,
-                       llvm::Attribute::OptimizeNone);
-    f->removeAttribute(llvm::AttributeList::FunctionIndex,
-                       llvm::Attribute::NoInline);
-    f->addAttribute(llvm::AttributeList::FunctionIndex,
-                    llvm::Attribute::AlwaysInline);
-#endif
+    f = llvm::cast<llvm::Function>(
+        module
+            ->getOrInsertFunction(name, f->getFunctionType(),
+                                  f->getAttributes())
+            .getCallee());
     return f;
   }
```
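`LLVMModuleBuilder` now delegates to the context: the runtime function's definition stays in the shared runtime module, and only a declaration with the same signature and attributes is inserted into the kernel module via `getOrInsertFunction`, to be resolved when the modules are linked or JITed. A small sketch of that declare-don't-clone pattern (`runtime_module` is an illustrative parameter, not Taichi's API):

```cpp
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Casting.h"

// Declare a runtime function inside kernel_module using the signature and
// attributes of its definition in the shared runtime module. No body is
// copied; a later link or JIT step resolves the call.
llvm::Function *import_runtime_decl(llvm::Module &kernel_module,
                                    llvm::Module &runtime_module,
                                    llvm::StringRef name) {
  llvm::Function *def = runtime_module.getFunction(name);
  if (!def)
    return nullptr;  // unknown runtime function
  return llvm::cast<llvm::Function>(
      kernel_module
          .getOrInsertFunction(name, def->getFunctionType(),
                               def->getAttributes())
          .getCallee());
}
```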
1 change: 1 addition & 0 deletions taichi/codegen/llvm/llvm_compiled_data.h
```diff
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <memory>
+#include <unordered_set>
 
 #include "llvm/IR/Module.h"
```
3 changes: 1 addition & 2 deletions taichi/codegen/llvm/struct_llvm.cpp
```diff
@@ -292,7 +292,7 @@ void StructCompilerLLVM::run(SNode &root) {
   auto node_type = get_llvm_node_type(module.get(), &root);
   root_size = tlctx_->get_type_size(node_type);
 
-  tlctx_->set_struct_module(module);
+  tlctx_->add_struct_module(std::move(module), root.get_snode_tree_id());
 }
 
 llvm::Type *StructCompilerLLVM::get_stub(llvm::Module *module,
@@ -336,7 +336,6 @@ llvm::Type *StructCompilerLLVM::get_llvm_element_type(llvm::Module *module,
 
 llvm::Function *StructCompilerLLVM::create_function(llvm::FunctionType *ft,
                                                     std::string func_name) {
-  tlctx_->add_function_to_snode_tree(snode_tree_id_, func_name);
   return llvm::Function::Create(ft, llvm::Function::ExternalLinkage, func_name,
                                 *module);
 }
```
25 changes: 21 additions & 4 deletions taichi/codegen/wasm/codegen_wasm.cpp
```diff
@@ -9,6 +9,7 @@
 #include "taichi/ir/statements.h"
 #include "taichi/util/statistics.h"
 #include "taichi/util/file_sequence_writer.h"
+#include "taichi/runtime/program_impls/llvm/llvm_program.h"
 
 namespace taichi {
 namespace lang {
@@ -243,10 +244,10 @@ class TaskCodeGenWASM : public TaskCodeGenLLVM {
 
 FunctionType KernelCodeGenWASM::compile_to_function() {
   TI_AUTO_PROF
-  TaskCodeGenWASM gen(kernel, ir);
-  auto res = gen.run_compilation();
-  gen.tlctx->add_module(std::move(res.module));
-  auto kernel_symbol = gen.tlctx->lookup_function_pointer(res.tasks[0].name);
+  auto linked = std::move(compile_kernel_to_module()[0]);
+  auto *tlctx = get_llvm_program(prog)->get_llvm_context(kernel->arch);
+  tlctx->create_jit_module(std::move(linked.module));
+  auto kernel_symbol = tlctx->lookup_function_pointer(linked.tasks[0].name);
   return [=](RuntimeContext &context) {
     TI_TRACE("Launching Taichi Kernel Function");
     auto func = (int32(*)(void *))kernel_symbol;
@@ -257,6 +258,7 @@
 LLVMCompiledData KernelCodeGenWASM::compile_task(
     std::unique_ptr<llvm::Module> &&module,
     OffloadedStmt *stmt) {
+  kernel->offload_to_executable(ir);
   bool init_flag = module == nullptr;
   std::vector<OffloadedTask> name_list;
   auto gen = std::make_unique<TaskCodeGenWASM>(kernel, ir, std::move(module));
@@ -278,5 +280,20 @@
 
   return {name_list, std::move(gen->module), {}, {}};
 }
+
+std::vector<LLVMCompiledData> KernelCodeGenWASM::compile_kernel_to_module() {
+  auto *tlctx = get_llvm_program(prog)->get_llvm_context(kernel->arch);
+  if (!kernel->lowered()) {
+    kernel->lower(/*to_executable=*/false);
+  }
+  auto res = compile_task();
+  std::vector<std::unique_ptr<LLVMCompiledData>> data;
+  data.push_back(std::make_unique<LLVMCompiledData>(std::move(res)));
+  auto linked = tlctx->link_compiled_tasks(std::move(data));
+  std::vector<LLVMCompiledData> ret;
+  ret.push_back(std::move(*linked));
+  return ret;
+}
 
 }  // namespace lang
 }  // namespace taichi
```
2 changes: 2 additions & 0 deletions taichi/codegen/wasm/codegen_wasm.h
```diff
@@ -23,6 +23,8 @@ class KernelCodeGenWASM : public KernelCodeGen {
   LLVMCompiledData compile_task(
       std::unique_ptr<llvm::Module> &&module = nullptr,
       OffloadedStmt *stmt = nullptr) override;  // AOT Module Gen
+
+  std::vector<LLVMCompiledData> compile_kernel_to_module() override;
 #endif
 };
```
4 changes: 2 additions & 2 deletions taichi/runtime/cpu/aot_module_builder_impl.cpp
```diff
@@ -10,8 +10,8 @@ namespace lang {
 namespace cpu {
 
 LLVMCompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
-  auto cgen = KernelCodeGenCPU::make_codegen_llvm(kernel, /*ir=*/nullptr);
-  return cgen->run_compilation();
+  auto cgen = KernelCodeGenCPU(kernel);
+  return std::move(cgen.compile_kernel_to_module()[0]);
 }
 
 }  // namespace cpu
```
4 changes: 2 additions & 2 deletions taichi/runtime/cuda/aot_module_builder_impl.cpp
```diff
@@ -10,8 +10,8 @@ namespace lang {
 namespace cuda {
 
 LLVMCompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
-  auto cgen = KernelCodeGenCUDA::make_codegen_llvm(kernel, /*ir=*/nullptr);
-  return cgen->run_compilation();
+  auto cgen = KernelCodeGenCUDA(kernel);
+  return std::move(cgen.compile_kernel_to_module()[0]);
 }
 
 }  // namespace cuda
```