[llvm] [refactor] (Decomp of #5251 9/n) Refactor CodeGen to support parallel compilation on LLVM backend #5387

Merged · 6 commits · Jul 11, 2022
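
What this PR does: as step 9/n of the #5251 decomposition, it moves the offline-cache read/write logic out of CodeGenLLVM and into KernelCodeGen and the per-backend codegen() implementations, leaving CodeGenLLVM::run_compilation() as a pure IR-to-module step that parallel compilation workers can run independently. A minimal, self-contained sketch of the resulting control flow (placeholder names, not Taichi's actual API):

#include <optional>
#include <string>
#include <utility>
#include <vector>

// Placeholder types standing in for Taichi's Kernel / LLVMCompiledData.
struct Kernel {};
struct CompiledData {};

// Stubs for the cache and the pure compile step; in the PR the real logic
// lives in LlvmOfflineCacheFileReader, LlvmProgramImpl, and CodeGenLLVM.
std::optional<std::vector<CompiledData>> try_read_cache(const std::string &) {
  return std::nullopt;  // pretend the cache never hits
}
void write_cache(const std::string &, const std::vector<CompiledData> &) {}
CompiledData run_compilation(const Kernel &) { return {}; }

std::vector<CompiledData> codegen_flow(const Kernel &kernel,
                                       const std::string &kernel_key,
                                       bool cache_enabled) {
  // Fast path: reuse a previously compiled kernel from the offline cache.
  if (cache_enabled) {
    if (auto cached = try_read_cache(kernel_key)) {
      return std::move(*cached);
    }
  }
  // Slow path: compile from IR. run_compilation() no longer touches the
  // cache, so several instances can run on worker threads concurrently.
  std::vector<CompiledData> data;
  data.push_back(run_compilation(kernel));
  // Persist the result for later processes and AOT / CGraph use.
  write_cache(kernel_key, data);
  return data;
}

CodeGenCPU::codegen() and CodeGenCUDA::codegen() in the diff below both follow this shape; a TODO in the CUDA version notes that the shared part should eventually move into the base class.
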
30 changes: 30 additions & 0 deletions taichi/codegen/codegen.cpp
@@ -6,6 +6,8 @@
#if defined(TI_WITH_LLVM)
#include "taichi/codegen/cpu/codegen_cpu.h"
#include "taichi/codegen/wasm/codegen_wasm.h"
#include "taichi/runtime/llvm/llvm_offline_cache.h"
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#endif
#if defined(TI_WITH_CUDA)
#include "taichi/codegen/cuda/codegen_cuda.h"
@@ -53,6 +55,34 @@ std::unique_ptr<KernelCodeGen> KernelCodeGen::create(Arch arch,
}
#ifdef TI_WITH_LLVM

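// Attempts to load the compiled data for `kernel_key` from the offline
// cache on disk; returns false (leaving `data` untouched) when no cache
// file or no matching entry exists.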
bool KernelCodeGen::maybe_read_compilation_from_cache(
const std::string &kernel_key,
std::vector<LLVMCompiledData> &data) {
const auto &config = prog->config;
auto reader =
LlvmOfflineCacheFileReader::make(config.offline_cache_file_path);
if (!reader) {
return false;
}

LlvmOfflineCache::KernelCacheData cache_data;
auto *tlctx = get_llvm_program(prog)->get_llvm_context(config.arch);
auto &llvm_ctx = *tlctx->get_this_thread_context();

if (!reader->get_kernel_cache(cache_data, kernel_key, llvm_ctx)) {
return false;
}
data.swap(cache_data.compiled_data_list);
kernel->set_from_offline_cache();
return true;
}

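// Saves the compiled kernel into LlvmProgramImpl's cache data, for dumping
// to the offline cache file and for successive AOT / CGraph use.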
void KernelCodeGen::cache_module(const std::string &kernel_key,
const std::vector<LLVMCompiledData> &data) {
get_llvm_program(prog)->cache_kernel(kernel_key, data,
infer_launch_args(kernel));
}

ModuleToFunctionConverter::ModuleToFunctionConverter(
TaichiLLVMContext *tlctx,
LlvmRuntimeExecutor *executor)
8 changes: 8 additions & 0 deletions taichi/codegen/codegen.h
@@ -28,12 +28,20 @@ class KernelCodeGen {
Stmt *stmt = nullptr);

virtual FunctionType codegen() = 0;
virtual bool supports_offline_cache() const {
return false;
}

#ifdef TI_WITH_LLVM
virtual LLVMCompiledData modulegen(
std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr) {
TI_NOT_IMPLEMENTED
}
bool maybe_read_compilation_from_cache(const std::string &kernel_key,
std::vector<LLVMCompiledData> &data);
void cache_module(const std::string &kernel_key,
const std::vector<LLVMCompiledData> &data);
#endif
};

56 changes: 39 additions & 17 deletions taichi/codegen/cpu/codegen_cpu.cpp
@@ -8,7 +8,9 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/util/statistics.h"

#include "taichi/ir/transforms.h"
#include "taichi/ir/analysis.h"
#include "taichi/analysis/offline_cache_util.h"
TLANG_NAMESPACE_BEGIN

namespace {
@@ -22,10 +24,6 @@ class CodeGenLLVMCPU : public CodeGenLLVM {
TI_AUTO_PROF
}

bool supports_offline_cache() const override {
return true;
}

void create_offload_range_for(OffloadedStmt *stmt) override {
int step = 1;

@@ -221,16 +219,6 @@ class CodeGenLLVMCPU : public CodeGenLLVM {
TI_NOT_IMPLEMENTED
}
}

FunctionType gen() override {
auto compiled_res = run_compilation();

CPUModuleToFunctionConverter converter{
tlctx, get_llvm_program(prog)->get_runtime_executor()};
std::vector<LLVMCompiledData> data;
data.push_back(std::move(compiled_res));
return converter.convert(kernel, std::move(data));
}
};

} // namespace
@@ -286,11 +274,45 @@ FunctionType CPUModuleToFunctionConverter::convert(
};
}

LLVMCompiledData CodeGenCPU::modulegen(std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
CodeGenLLVMCPU gen(kernel, stmt);
return gen.run_compilation();
}
#endif // TI_WITH_LLVM

FunctionType CodeGenCPU::codegen() {
TI_AUTO_PROF;
return CodeGenLLVMCPU(kernel, ir).gen();
}
auto *llvm_prog = get_llvm_program(prog);
auto *tlctx = llvm_prog->get_llvm_context(kernel->arch);
auto &config = prog->config;
std::string kernel_key = get_hashed_offline_cache_key(&config, kernel);
kernel->set_kernel_key_for_cache(kernel_key);
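// Fast path: if offline caching is enabled and a compiled copy of this
// kernel already exists on disk, load it instead of lowering and
// compiling again.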
if (config.offline_cache && !config.async_mode &&
this->supports_offline_cache() && !kernel->is_evaluator) {
std::vector<LLVMCompiledData> res;
const bool ok = maybe_read_compilation_from_cache(kernel_key, res);
if (ok) {
CPUModuleToFunctionConverter converter(
tlctx, get_llvm_program(prog)->get_runtime_executor());
return converter.convert(kernel, std::move(res));
}
}
if (!kernel->lowered()) {
kernel->lower(/*to_executable=*/false);
}

CodeGenLLVMCPU gen(kernel, ir);
auto compiled_res = gen.run_compilation();

CPUModuleToFunctionConverter converter{gen.tlctx,
llvm_prog->get_runtime_executor()};
std::vector<LLVMCompiledData> data_list;
data_list.push_back(std::move(compiled_res));
if (!kernel->is_evaluator) {
cache_module(kernel_key, data_list);
}

return converter.convert(this->kernel, std::move(data_list));
}
TLANG_NAMESPACE_END
6 changes: 6 additions & 0 deletions taichi/codegen/cpu/codegen_cpu.h
@@ -20,6 +20,12 @@ class CodeGenCPU : public KernelCodeGen {
IRNode *ir);
#endif // TI_WITH_LLVM

bool supports_offline_cache() const override {
return true;
}
LLVMCompiledData modulegen(std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr) override;

FunctionType codegen() override;
};

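With modulegen() compiling a single OffloadedStmt into its own LLVMCompiledData, a follow-up change can fan offloaded tasks out across threads. A hypothetical sketch of that usage — not part of this PR; compile_one_task and the placeholder types are stand-ins:

#include <future>
#include <vector>

struct OffloadedStmt {};  // placeholder
struct CompiledTask {};   // stands in for LLVMCompiledData

// Stand-in for KernelCodeGen::modulegen(/*module=*/nullptr, stmt).
CompiledTask compile_one_task(OffloadedStmt *stmt) {
  (void)stmt;
  return {};
}

std::vector<CompiledTask> compile_in_parallel(
    const std::vector<OffloadedStmt *> &tasks) {
  std::vector<std::future<CompiledTask>> futures;
  futures.reserve(tasks.size());
  for (auto *stmt : tasks) {
    // Each offloaded task is compiled in isolation: no shared mutable state.
    futures.push_back(std::async(std::launch::async, compile_one_task, stmt));
  }
  std::vector<CompiledTask> results;
  results.reserve(futures.size());
  for (auto &f : futures) {
    results.push_back(f.get());
  }
  return results;
}
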
50 changes: 33 additions & 17 deletions taichi/codegen/cuda/codegen_cuda.cpp
@@ -15,7 +15,7 @@
#include "taichi/rhi/cuda/cuda_context.h"
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#include "taichi/util/action_recorder.h"

#include "taichi/analysis/offline_cache_util.h"
TLANG_NAMESPACE_BEGIN

using namespace llvm;
@@ -31,21 +31,6 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
: CodeGenLLVM(kernel, ir) {
}

bool supports_offline_cache() const override {
return true;
}

FunctionType gen() override {
auto compiled_res = run_compilation();

auto *llvm_prog = get_llvm_program(kernel->program);
CUDAModuleToFunctionConverter converter{tlctx,
llvm_prog->get_runtime_executor()};
std::vector<LLVMCompiledData> data;
data.push_back(std::move(compiled_res));
return converter.convert(this->kernel, std::move(data));
}

llvm::Value *create_print(std::string tag,
DataType dt,
llvm::Value *value) override {
@@ -737,7 +722,38 @@ static void set_arg_external_array(RuntimeContext *ctx,

FunctionType CodeGenCUDA::codegen() {
TI_AUTO_PROF
return CodeGenLLVMCUDA(kernel, ir).gen();
// TODO: move the offline cache part to the base class
auto *llvm_prog = get_llvm_program(prog);
auto *tlctx = llvm_prog->get_llvm_context(kernel->arch);
auto &config = prog->config;
std::string kernel_key = get_hashed_offline_cache_key(&config, kernel);
kernel->set_kernel_key_for_cache(kernel_key);
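// Fast path mirroring CodeGenCPU::codegen(); see the TODO above about
// hoisting this shared logic into the base class.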
if (config.offline_cache && !config.async_mode &&
this->supports_offline_cache() && !kernel->is_evaluator) {
std::vector<LLVMCompiledData> res;
const bool ok = maybe_read_compilation_from_cache(kernel_key, res);
if (ok) {
CUDAModuleToFunctionConverter converter(
tlctx, get_llvm_program(prog)->get_runtime_executor());
return converter.convert(kernel, std::move(res));
}
}
if (!kernel->lowered()) {
kernel->lower(/*to_executable=*/false);
}

CodeGenLLVMCUDA gen(kernel, ir);
auto compiled_res = gen.run_compilation();

CUDAModuleToFunctionConverter converter{gen.tlctx,
llvm_prog->get_runtime_executor()};
std::vector<LLVMCompiledData> data_list;
data_list.push_back(std::move(compiled_res));
if (!kernel->is_evaluator) {
cache_module(kernel_key, data_list);
}

return converter.convert(this->kernel, std::move(data_list));
}

FunctionType CUDAModuleToFunctionConverter::convert(
4 changes: 4 additions & 0 deletions taichi/codegen/cuda/codegen_cuda.h
@@ -19,6 +19,10 @@ class CodeGenCUDA : public KernelCodeGen {
IRNode *ir);
#endif // TI_WITH_LLVM

bool supports_offline_cache() const override {
return true;
}

FunctionType codegen() override;
};

60 changes: 5 additions & 55 deletions taichi/codegen/llvm/codegen_llvm.cpp
@@ -2547,58 +2547,15 @@ void CodeGenLLVM::emit_to_module() {
}

LLVMCompiledData CodeGenLLVM::run_compilation() {
const auto &config = prog->config;
std::string kernel_key =
get_hashed_offline_cache_key(&kernel->program->config, kernel);
kernel->set_kernel_key_for_cache(kernel_key);
if (config.offline_cache && !config.async_mode &&
this->supports_offline_cache() && !kernel->is_evaluator) {
LLVMCompiledData res;
const bool ok = maybe_read_compilation_from_cache(kernel_key, &res);
if (ok) {
return res;
}
}
// Final lowering

auto config = kernel->program->config;
kernel->offload_to_executable(ir);

if (!kernel->lowered()) {
kernel->lower();
}
emit_to_module();
eliminate_unused_functions();

// Updates LlvmProgramImpl->cache_data_ to save the compiled kernel
// information for successive uses in AOT or CGraph.
if (!kernel->is_evaluator) {
cache_module(kernel_key);
}

LLVMCompiledData res;
res.tasks = std::move(this->offloaded_tasks);
res.module = std::move(this->module);
return res;
}

bool CodeGenLLVM::maybe_read_compilation_from_cache(
const std::string &kernel_key,
LLVMCompiledData *data) {
const auto &config = prog->config;
auto reader =
LlvmOfflineCacheFileReader::make(config.offline_cache_file_path);
if (!reader) {
return false;
}

LlvmOfflineCache::KernelCacheData cache_data;
auto *tlctx = get_llvm_program(prog)->get_llvm_context(config.arch);
auto &llvm_ctx = *tlctx->get_this_thread_context();

if (!reader->get_kernel_cache(cache_data, kernel_key, llvm_ctx)) {
return false;
}
data->tasks = std::move(cache_data.compiled_data_list[0].tasks);
data->module = std::move(cache_data.compiled_data_list[0].module);
kernel->set_from_offline_cache();
return true;
return {std::move(this->offloaded_tasks), std::move(this->module)};
}

llvm::Value *CodeGenLLVM::create_xlogue(std::unique_ptr<Block> &block) {
@@ -2672,13 +2629,6 @@ void CodeGenLLVM::visit(FuncCallStmt *stmt) {
}
}

void CodeGenLLVM::cache_module(const std::string &kernel_key) {
std::vector<LLVMCompiledData> data;
data.emplace_back(offloaded_tasks, llvm::CloneModule(*module));
get_llvm_program(prog)->cache_kernel(kernel_key, data,
infer_launch_args(kernel));
}

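// clone() deep-copies the wrapped llvm::Module via llvm::CloneModule, so a
// compilation result can be cached and consumed independently.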
LLVMCompiledData LLVMCompiledData::clone() const {
return {tasks, llvm::CloneModule(*module)};
}
16 changes: 1 addition & 15 deletions taichi/codegen/llvm/codegen_llvm.h
@@ -139,15 +139,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
*
* @return LLVMCompiledData
*/
LLVMCompiledData run_compilation();

// TODO: This function relies largely on `run_compilation()`. Name it better.
virtual FunctionType gen(){TI_NOT_IMPLEMENTED};

virtual bool supports_offline_cache() const {
return false;
}

virtual LLVMCompiledData run_compilation();
// For debugging only
virtual llvm::Value *create_print(std::string tag,
DataType dt,
@@ -429,12 +421,6 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type);

~CodeGenLLVM() override = default;

private:
bool maybe_read_compilation_from_cache(const std::string &kernel_key,
LLVMCompiledData *data);

void cache_module(const std::string &kernel_key);
};

} // namespace lang
24 changes: 14 additions & 10 deletions taichi/codegen/wasm/codegen_wasm.cpp
@@ -215,8 +215,7 @@ class CodeGenLLVMWASM : public CodeGenLLVM {
TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs()));
}

FunctionType gen() override {
TI_AUTO_PROF
LLVMCompiledData run_compilation() override {
// lower kernel
if (!kernel->lowered()) {
kernel->lower();
@@ -236,19 +235,24 @@
}
return func_name == offloaded_task_name;
});
tlctx->add_module(std::move(module));
auto kernel_symbol = tlctx->lookup_function_pointer(offloaded_task_name);
return [=](RuntimeContext &context) {
TI_TRACE("Launching Taichi Kernel Function");
auto func = (int32(*)(void *))kernel_symbol;
func(&context);
};
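// Record the entry task's name; codegen() uses it to look up the JITed
// symbol after handing the module to the JIT session.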
LLVMCompiledData res;
res.tasks.emplace_back(offloaded_task_name);
res.module = std::move(this->module);
return res;
}
};

FunctionType CodeGenWASM::codegen() {
TI_AUTO_PROF
return CodeGenLLVMWASM(kernel, ir).gen();
CodeGenLLVMWASM gen(kernel, ir);
auto res = gen.run_compilation();
gen.tlctx->add_module(std::move(res.module));
auto kernel_symbol = gen.tlctx->lookup_function_pointer(res.tasks[0].name);
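// Wrap the raw JITed symbol in a FunctionType trampoline the runtime can
// invoke with a RuntimeContext.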
return [=](RuntimeContext &context) {
TI_TRACE("Launching Taichi Kernel Function");
auto func = (int32(*)(void *))kernel_symbol;
func(&context);
};
}

LLVMCompiledData CodeGenWASM::modulegen(std::unique_ptr<llvm::Module> &&module,