[llvm] [refactor] (Decomp of #5251 6/n) Let ModuleToFunctionConverter support multiple modules #5372

Merged · 4 commits · Jul 9, 2022
Diff view: changes from 2 commits.
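At a high level, this PR changes ModuleToFunctionConverter::convert to take a std::vector<LLVMCompiledData> (each entry bundles one llvm::Module with its OffloadedTask list) instead of a single module plus task vector, moves the converter's declaration from codegen_llvm.h into codegen.h, and adds a CPUModuleToFunctionConverter. A minimal sketch of the resulting call pattern, mirroring the CodeGenLLVMCPU::gen() override in the diff below (tlctx, prog, kernel, and run_compilation are members of the surrounding CodeGenLLVM class and are assumed here, not defined in this sketch):

  // Sketch only: mirrors CodeGenLLVMCPU::gen() from this diff; all names
  // come from the enclosing CodeGenLLVM context.
  FunctionType gen() override {
    LLVMCompiledData compiled_res = run_compilation();

    CPUModuleToFunctionConverter converter{
        tlctx, get_llvm_program(prog)->get_runtime_executor()};

    // The converter now takes a vector, so a kernel compiled into several
    // LLVM modules can pass all of them in one call.
    std::vector<LLVMCompiledData> data;
    data.push_back(std::move(compiled_res));
    return converter.convert(kernel, std::move(data));
  }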
14 changes: 14 additions & 0 deletions taichi/codegen/codegen.cpp
@@ -51,5 +51,19 @@ std::unique_ptr<KernelCodeGen> KernelCodeGen::create(Arch arch,
TI_ERROR("Llvm disabled");
#endif
}
#ifdef TI_WITH_LLVM

ModuleToFunctionConverter::ModuleToFunctionConverter(
TaichiLLVMContext *tlctx,
LlvmRuntimeExecutor *executor)
: tlctx_(tlctx), executor_(executor) {
}

FunctionType ModuleToFunctionConverter::convert(
const Kernel *kernel,
std::vector<LLVMCompiledData> &&data) const {
return convert(kernel->name, infer_launch_args(kernel), std::move(data));
}

#endif
TLANG_NAMESPACE_END
23 changes: 23 additions & 0 deletions taichi/codegen/codegen.h
@@ -1,6 +1,7 @@
// Driver class for kernel code generators.

#pragma once
#include <taichi/runtime/llvm/llvm_runtime_executor.h>
#include "taichi/ir/ir.h"
#include "taichi/program/program.h"
#ifdef TI_WITH_LLVM
@@ -36,4 +37,26 @@ class KernelCodeGen {
#endif
};

#ifdef TI_WITH_LLVM

class ModuleToFunctionConverter {
public:
explicit ModuleToFunctionConverter(TaichiLLVMContext *tlctx,
LlvmRuntimeExecutor *executor);

virtual ~ModuleToFunctionConverter() = default;

virtual FunctionType convert(const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
std::vector<LLVMCompiledData> &&data) const = 0;

virtual FunctionType convert(const Kernel *kernel,
std::vector<LLVMCompiledData> &&data) const;

protected:
TaichiLLVMContext *tlctx_{nullptr};
LlvmRuntimeExecutor *executor_{nullptr};
};

#endif
TLANG_NAMESPACE_END
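For reference, only the string-based convert overload above is pure virtual after this change; the Kernel * overload forwards to it through infer_launch_args. A hedged sketch of what a hypothetical new backend converter would implement (MyBackendModuleToFunctionConverter is an illustrative name, not part of this PR):

  // Sketch only. ModuleToFunctionConverter, FunctionType, LlvmLaunchArgInfo,
  // and LLVMCompiledData are the Taichi types declared above; "MyBackend" is
  // a hypothetical placeholder.
  class MyBackendModuleToFunctionConverter : public ModuleToFunctionConverter {
   public:
    using ModuleToFunctionConverter::ModuleToFunctionConverter;
    // Keep the inherited Kernel * overload visible alongside this override.
    using ModuleToFunctionConverter::convert;

    FunctionType convert(const std::string &kernel_name,
                         const std::vector<LlvmLaunchArgInfo> &args,
                         std::vector<LLVMCompiledData> &&data) const override {
      // A real backend would JIT each entry of `data` and return a launcher
      // closure, as the CPU and CUDA converters below do.
      TI_NOT_IMPLEMENTED;
    }
  };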
55 changes: 55 additions & 0 deletions taichi/codegen/cpu/codegen_cpu.cpp
@@ -209,6 +209,16 @@ class CodeGenLLVMCPU : public CodeGenLLVM {
TI_NOT_IMPLEMENTED
}
}

FunctionType gen() override {
auto compiled_res = run_compilation();

CPUModuleToFunctionConverter converter{
tlctx, get_llvm_program(prog)->get_runtime_executor()};
std::vector<LLVMCompiledData> data;
data.push_back(std::move(compiled_res));
return converter.convert(kernel, std::move(data));
}
};

} // namespace
@@ -219,6 +229,51 @@ std::unique_ptr<CodeGenLLVM> CodeGenCPU::make_codegen_llvm(Kernel *kernel,
IRNode *ir) {
return std::make_unique<CodeGenLLVMCPU>(kernel, ir);
}

FunctionType CPUModuleToFunctionConverter::convert(
const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
std::vector<LLVMCompiledData> &&data) const {
for (auto &datum : data) {
tlctx_->add_module(std::move(datum.module));
}

using TaskFunc = int32 (*)(void *);
std::vector<TaskFunc> task_funcs;
task_funcs.reserve(data.size());
for (auto &datum : data) {
for (auto &task : datum.tasks) {
auto *func_ptr = tlctx_->lookup_function_pointer(task.name);
TI_ASSERT_INFO(func_ptr, "Offloaded task function {} not found",
task.name);
task_funcs.push_back((TaskFunc)(func_ptr));
}
}
// Do NOT capture `this`...
return [executor = this->executor_, args, kernel_name,
task_funcs](RuntimeContext &context) {
TI_TRACE("Launching kernel {}", kernel_name);
// For taichi ndarrays, context.args saves pointer to its
// |DeviceAllocation|, CPU backend actually want to use the raw ptr here.
for (int i = 0; i < (int)args.size(); i++) {
if (args[i].is_array &&
context.device_allocation_type[i] !=
RuntimeContext::DevAllocType::kNone &&
context.array_runtime_sizes[i] > 0) {
DeviceAllocation *ptr =
static_cast<DeviceAllocation *>(context.get_arg<void *>(i));
uint64 host_ptr = (uint64)executor->get_ndarray_alloc_info_ptr(*ptr);
context.set_arg(i, host_ptr);
context.set_array_device_allocation_type(
i, RuntimeContext::DevAllocType::kNone);
}
}
for (auto task : task_funcs) {
task(&context);
}
};
}

#endif // TI_WITH_LLVM

FunctionType CodeGenCPU::codegen() {
18 changes: 18 additions & 0 deletions taichi/codegen/cpu/codegen_cpu.h
@@ -23,4 +23,22 @@ class CodeGenCPU : public KernelCodeGen {
FunctionType codegen() override;
};

#ifdef TI_WITH_LLVM

class CPUModuleToFunctionConverter : public ModuleToFunctionConverter {
public:
explicit CPUModuleToFunctionConverter(TaichiLLVMContext *tlctx,
LlvmRuntimeExecutor *executor)
: ModuleToFunctionConverter(tlctx, executor) {
}

using ModuleToFunctionConverter::convert;

FunctionType convert(const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
std::vector<LLVMCompiledData> &&data) const override;
};

#endif

TLANG_NAMESPACE_END
21 changes: 9 additions & 12 deletions taichi/codegen/cuda/codegen_cuda.cpp
@@ -41,9 +41,9 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
auto *llvm_prog = get_llvm_program(kernel->program);
CUDAModuleToFunctionConverter converter{tlctx,
llvm_prog->get_runtime_executor()};

return converter.convert(this->kernel, std::move(compiled_res.module),
std::move(compiled_res.tasks));
std::vector<LLVMCompiledData> data;
data.push_back(std::move(compiled_res));
return converter.convert(this->kernel, std::move(data));
}

llvm::Value *create_print(std::string tag,
@@ -738,11 +738,14 @@ FunctionType CodeGenCUDA::codegen() {
return CodeGenLLVMCUDA(kernel, ir).gen();
}

#ifdef TI_WITH_LLVM

FunctionType CUDAModuleToFunctionConverter::convert(
const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
std::unique_ptr<llvm::Module> mod,
std::vector<OffloadedTask> &&tasks) const {
std::vector<LLVMCompiledData> &&data) const {
auto &mod = data[0].module;
auto &tasks = data[0].tasks;
#ifdef TI_WITH_CUDA
for (const auto &task : tasks) {
llvm::Function *func = mod->getFunction(task.name);
Expand Down Expand Up @@ -847,12 +850,6 @@ FunctionType CUDAModuleToFunctionConverter::convert(
#endif // TI_WITH_CUDA
}

FunctionType CUDAModuleToFunctionConverter::convert(
const Kernel *kernel,
std::unique_ptr<llvm::Module> mod,
std::vector<OffloadedTask> &&tasks) const {
return convert(kernel->name, infer_launch_args(kernel), std::move(mod),
std::move(tasks));
}
#endif

TLANG_NAMESPACE_END
12 changes: 6 additions & 6 deletions taichi/codegen/cuda/codegen_cuda.h
@@ -22,21 +22,21 @@ class CodeGenCUDA : public KernelCodeGen {
FunctionType codegen() override;
};

#ifdef TI_WITH_LLVM

class CUDAModuleToFunctionConverter : public ModuleToFunctionConverter {
public:
explicit CUDAModuleToFunctionConverter(TaichiLLVMContext *tlctx,
LlvmRuntimeExecutor *executor)
: ModuleToFunctionConverter(tlctx, executor) {
}
using ModuleToFunctionConverter::convert;

FunctionType convert(const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
std::unique_ptr<llvm::Module> mod,
std::vector<OffloadedTask> &&tasks) const override;

FunctionType convert(const Kernel *kernel,
std::unique_ptr<llvm::Module> mod,
std::vector<OffloadedTask> &&tasks) const override;
std::vector<LLVMCompiledData> &&data) const override;
};

#endif

TLANG_NAMESPACE_END
63 changes: 0 additions & 63 deletions taichi/codegen/llvm/codegen_llvm.cpp
@@ -2374,15 +2374,6 @@ bool CodeGenLLVM::maybe_read_compilation_from_cache(
return true;
}

FunctionType CodeGenLLVM::gen() {
auto compiled_res = run_compilation();

ModuleToFunctionConverter converter{
tlctx, get_llvm_program(prog)->get_runtime_executor()};
return converter.convert(kernel, std::move(compiled_res.module),
std::move(compiled_res.tasks));
}

llvm::Value *CodeGenLLVM::create_xlogue(std::unique_ptr<Block> &block) {
llvm::Value *xlogue;

@@ -2457,60 +2448,6 @@ void CodeGenLLVM::cache_module(const std::string &kernel_key) {
std::move(offloaded_task_list));
}

ModuleToFunctionConverter::ModuleToFunctionConverter(
TaichiLLVMContext *tlctx,
LlvmRuntimeExecutor *executor)
: tlctx_(tlctx), executor_(executor) {
}

FunctionType ModuleToFunctionConverter::convert(
const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
std::unique_ptr<llvm::Module> mod,
std::vector<OffloadedTask> &&tasks) const {
tlctx_->add_module(std::move(mod));

using TaskFunc = int32 (*)(void *);
std::vector<TaskFunc> task_funcs;
task_funcs.reserve(tasks.size());
for (auto &task : tasks) {
auto *func_ptr = tlctx_->lookup_function_pointer(task.name);
TI_ASSERT_INFO(func_ptr, "Offloaded task function {} not found", task.name);
task_funcs.push_back((TaskFunc)(func_ptr));
}
// Do NOT capture `this`...
return [executor = this->executor_, args, kernel_name,
task_funcs](RuntimeContext &context) {
TI_TRACE("Launching kernel {}", kernel_name);
// For taichi ndarrays, context.args saves pointer to its
// |DeviceAllocation|, CPU backend actually want to use the raw ptr here.
for (int i = 0; i < (int)args.size(); i++) {
if (args[i].is_array &&
context.device_allocation_type[i] !=
RuntimeContext::DevAllocType::kNone &&
context.array_runtime_sizes[i] > 0) {
DeviceAllocation *ptr =
static_cast<DeviceAllocation *>(context.get_arg<void *>(i));
uint64 host_ptr = (uint64)executor->get_ndarray_alloc_info_ptr(*ptr);
context.set_arg(i, host_ptr);
context.set_array_device_allocation_type(
i, RuntimeContext::DevAllocType::kNone);
}
}
for (auto task : task_funcs) {
task(&context);
}
};
}

FunctionType ModuleToFunctionConverter::convert(
const Kernel *kernel,
std::unique_ptr<llvm::Module> mod,
std::vector<OffloadedTask> &&tasks) const {
return convert(kernel->name, infer_launch_args(kernel), std::move(mod),
std::move(tasks));
}

TLANG_NAMESPACE_END

#endif // #ifdef TI_WITH_LLVM
34 changes: 8 additions & 26 deletions taichi/codegen/llvm/codegen_llvm.h
@@ -45,6 +45,13 @@ class FunctionCreationGuard {
struct LLVMCompiledData {
std::vector<OffloadedTask> tasks;
std::unique_ptr<llvm::Module> module{nullptr};
LLVMCompiledData() = default;
LLVMCompiledData(LLVMCompiledData &&) = default;
LLVMCompiledData(std::vector<OffloadedTask> tasks,
std::unique_ptr<llvm::Module> module)
: tasks(std::move(tasks)), module(std::move(module)) {
}
TI_IO_DEF(tasks);
};

class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
@@ -134,7 +141,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
LLVMCompiledData run_compilation();

// TODO: This function relies largely on `run_compilation()`. Name it better.
virtual FunctionType gen();
virtual FunctionType gen(){TI_NOT_IMPLEMENTED};

virtual bool supports_offline_cache() const {
return false;
@@ -410,31 +417,6 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
void cache_module(const std::string &kernel_key);
};

class LlvmRuntimeExecutor;

// TODO: Make ModuleToFunctionConverter abstract,
// Move CPU implementation to "taichi/backend/cpu/"
class ModuleToFunctionConverter {
public:
explicit ModuleToFunctionConverter(TaichiLLVMContext *tlctx,
LlvmRuntimeExecutor *executor);

virtual ~ModuleToFunctionConverter() = default;

virtual FunctionType convert(const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
std::unique_ptr<llvm::Module> mod,
std::vector<OffloadedTask> &&tasks) const;

virtual FunctionType convert(const Kernel *kernel,
std::unique_ptr<llvm::Module> mod,
std::vector<OffloadedTask> &&tasks) const;

protected:
TaichiLLVMContext *tlctx_{nullptr};
LlvmRuntimeExecutor *executor_{nullptr};
};

} // namespace lang
} // namespace taichi

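The LLVMCompiledData changes above add a defaulted move constructor and a two-argument constructor, and TI_IO_DEF(tasks) serializes only the task metadata, not the llvm::Module itself. A small illustrative sketch (make_single_module_data is a hypothetical helper; it assumes <memory> plus the Taichi types from this diff are in scope):

  #include <llvm/IR/LLVMContext.h>
  #include <llvm/IR/Module.h>

  // Hypothetical helper, for illustration only.
  std::vector<LLVMCompiledData> make_single_module_data(llvm::LLVMContext &ctx) {
    std::vector<OffloadedTask> tasks;  // normally filled by run_compilation()
    auto module = std::make_unique<llvm::Module>("demo_kernel", ctx);

    // The new two-argument constructor packs a module with its task list.
    LLVMCompiledData datum(std::move(tasks), std::move(module));

    // The defaulted move constructor makes vector storage possible; copying
    // is forbidden by the std::unique_ptr<llvm::Module> member.
    std::vector<LLVMCompiledData> data;
    data.push_back(std::move(datum));
    return data;
  }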
11 changes: 6 additions & 5 deletions taichi/runtime/cpu/aot_module_loader_impl.cpp
@@ -3,7 +3,7 @@

#include "taichi/runtime/llvm/llvm_offline_cache.h"
#include "taichi/runtime/llvm/llvm_runtime_executor.h"
#include "taichi/codegen/llvm/codegen_llvm.h"
#include "taichi/codegen/cpu/codegen_cpu.h"

namespace taichi {
namespace lang {
@@ -23,10 +23,11 @@ class AotModuleImpl : public LlvmAotModule {
TI_ASSERT(arch == Arch::x64 || arch == Arch::arm64);
auto *tlctx = executor_->get_llvm_context(arch);

ModuleToFunctionConverter converter{tlctx, executor_};

return converter.convert(name, loaded.args, std::move(loaded.owned_module),
std::move(loaded.offloaded_task_list));
CPUModuleToFunctionConverter converter{tlctx, executor_};
std::vector<LLVMCompiledData> data;
data.emplace_back(std::move(loaded.offloaded_task_list),
std::move(loaded.owned_module));
return converter.convert(name, loaded.args, std::move(data));
}

std::unique_ptr<aot::KernelTemplate> make_new_kernel_template(
6 changes: 4 additions & 2 deletions taichi/runtime/cuda/aot_module_loader_impl.cpp
@@ -25,8 +25,10 @@ class AotModuleImpl : public LlvmAotModule {

CUDAModuleToFunctionConverter converter{tlctx, executor_};

return converter.convert(name, loaded.args, std::move(loaded.owned_module),
std::move(loaded.offloaded_task_list));
std::vector<LLVMCompiledData> data;
data.emplace_back(std::move(loaded.offloaded_task_list),
std::move(loaded.owned_module));
return converter.convert(name, loaded.args, std::move(data));
}

std::unique_ptr<aot::KernelTemplate> make_new_kernel_template(
Expand Down