[llvm] [refactor] (Decomp of #5251 9/n) Refactor CodeGen to support parallel compilation on LLVM backend (#5387)

* [llvm] [refactor] (Decomp of #5251 9/n) Refactor CodeGen to support parallel compilation on CPU

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
lin-hitonami and pre-commit-ci[bot] authored Jul 11, 2022
1 parent 113981f commit 577cbec
Showing 11 changed files with 157 additions and 114 deletions.
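Note: the refactor gives every LLVM backend's codegen() the same shape: hash the kernel into a cache key, try the offline cache, and only on a miss lower, compile, cache, and convert to a launchable function. A self-contained toy of that cache-guarded flow, with standard-library stand-ins for LLVMCompiledData and the offline cache (the real types appear in the diff below):

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-ins: CompiledData plays LLVMCompiledData; Cache plays the offline cache.
struct CompiledData { std::string blob; };
using Cache = std::unordered_map<std::string, std::vector<CompiledData>>;

// Mirrors the refactored codegen() flow: a cache hit skips compilation
// entirely; a miss compiles once, populates the cache, then converts as usual.
std::function<void()> codegen(const std::string &kernel_key, Cache &cache) {
  std::vector<CompiledData> data;
  if (auto it = cache.find(kernel_key); it != cache.end()) {
    data = it->second;                           // maybe_read_compilation_from_cache()
  } else {
    data.push_back({"compiled " + kernel_key});  // run_compilation()
    cache[kernel_key] = data;                    // cache_module()
  }
  return [data] {                                // converter.convert()
    for (const auto &d : data) std::cout << "launch " << d.blob << "\n";
  };
}

int main() {
  Cache cache;
  codegen("kernel_abc", cache)();  // miss: compiles and caches
  codegen("kernel_abc", cache)();  // hit: reuses the cached modules
}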
30 changes: 30 additions & 0 deletions taichi/codegen/codegen.cpp
@@ -6,6 +6,8 @@
#if defined(TI_WITH_LLVM)
#include "taichi/codegen/cpu/codegen_cpu.h"
#include "taichi/codegen/wasm/codegen_wasm.h"
#include "taichi/runtime/llvm/llvm_offline_cache.h"
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#endif
#if defined(TI_WITH_CUDA)
#include "taichi/codegen/cuda/codegen_cuda.h"
@@ -53,6 +55,34 @@ std::unique_ptr<KernelCodeGen> KernelCodeGen::create(Arch arch,
}
#ifdef TI_WITH_LLVM

bool KernelCodeGen::maybe_read_compilation_from_cache(
const std::string &kernel_key,
std::vector<LLVMCompiledData> &data) {
const auto &config = prog->config;
auto reader =
LlvmOfflineCacheFileReader::make(config.offline_cache_file_path);
if (!reader) {
return false;
}

LlvmOfflineCache::KernelCacheData cache_data;
auto *tlctx = get_llvm_program(prog)->get_llvm_context(config.arch);
auto &llvm_ctx = *tlctx->get_this_thread_context();

if (!reader->get_kernel_cache(cache_data, kernel_key, llvm_ctx)) {
return false;
}
data.swap(cache_data.compiled_data_list);
kernel->set_from_offline_cache();
return true;
}

void KernelCodeGen::cache_module(const std::string &kernel_key,
const std::vector<LLVMCompiledData> &data) {
get_llvm_program(prog)->cache_kernel(kernel_key, data,
infer_launch_args(kernel));
}

ModuleToFunctionConverter::ModuleToFunctionConverter(
TaichiLLVMContext *tlctx,
LlvmRuntimeExecutor *executor)
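Note: the two helpers above moved out of CodeGenLLVM into the KernelCodeGen base class, and they now pass std::vector<LLVMCompiledData> instead of a single module, leaving room for one module per offloaded task once compilation is parallel. The reader side is fail-fast: each step returns false so the caller falls through to normal compilation. A minimal analogue of that shape, with a toy reader in place of LlvmOfflineCacheFileReader:

#include <memory>
#include <string>
#include <vector>

// Toy reader: make() returns nullptr when no usable cache exists,
// mirroring LlvmOfflineCacheFileReader::make() in the diff.
struct CacheReader {
  static std::unique_ptr<CacheReader> make(const std::string &path) {
    if (path.empty()) return nullptr;
    return std::make_unique<CacheReader>();
  }
  bool get_kernel_cache(std::vector<std::string> &out, const std::string &key) {
    if (key != "known_kernel") return false;  // cache miss
    out = {"task_0", "task_1"};               // one entry per offloaded task
    return true;
  }
};

// Same fail-fast shape as maybe_read_compilation_from_cache().
bool maybe_read(const std::string &path, const std::string &key,
                std::vector<std::string> &data) {
  auto reader = CacheReader::make(path);
  if (!reader) return false;                  // no cache: compile normally
  std::vector<std::string> cached;
  if (!reader->get_kernel_cache(cached, key)) return false;  // miss
  data.swap(cached);                          // hit: hand tasks to the caller
  return true;
}

int main() {
  std::vector<std::string> data;
  return maybe_read("/tmp/ticache", "known_kernel", data) ? 0 : 1;
}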
8 changes: 8 additions & 0 deletions taichi/codegen/codegen.h
@@ -28,12 +28,20 @@ class KernelCodeGen {
Stmt *stmt = nullptr);

virtual FunctionType codegen() = 0;
virtual bool supports_offline_cache() const {
return false;
}

#ifdef TI_WITH_LLVM
virtual LLVMCompiledData modulegen(
std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr) {
TI_NOT_IMPLEMENTED
}
bool maybe_read_compilation_from_cache(const std::string &kernel_key,
std::vector<LLVMCompiledData> &data);
void cache_module(const std::string &kernel_key,
const std::vector<LLVMCompiledData> &data);
#endif
};

56 changes: 39 additions & 17 deletions taichi/codegen/cpu/codegen_cpu.cpp
@@ -8,7 +8,9 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/util/statistics.h"

#include "taichi/ir/transforms.h"
#include "taichi/ir/analysis.h"
#include "taichi/analysis/offline_cache_util.h"
TLANG_NAMESPACE_BEGIN

namespace {
@@ -22,10 +24,6 @@ class CodeGenLLVMCPU : public CodeGenLLVM {
TI_AUTO_PROF
}

bool supports_offline_cache() const override {
return true;
}

void create_offload_range_for(OffloadedStmt *stmt) override {
int step = 1;

@@ -221,16 +219,6 @@ class CodeGenLLVMCPU : public CodeGenLLVM {
TI_NOT_IMPLEMENTED
}
}

FunctionType gen() override {
auto compiled_res = run_compilation();

CPUModuleToFunctionConverter converter{
tlctx, get_llvm_program(prog)->get_runtime_executor()};
std::vector<LLVMCompiledData> data;
data.push_back(std::move(compiled_res));
return converter.convert(kernel, std::move(data));
}
};

} // namespace
@@ -286,11 +274,45 @@ FunctionType CPUModuleToFunctionConverter::convert(
};
}

LLVMCompiledData CodeGenCPU::modulegen(std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
CodeGenLLVMCPU gen(kernel, stmt);
return gen.run_compilation();
}
#endif // TI_WITH_LLVM

FunctionType CodeGenCPU::codegen() {
TI_AUTO_PROF;
return CodeGenLLVMCPU(kernel, ir).gen();
}
auto *llvm_prog = get_llvm_program(prog);
auto *tlctx = llvm_prog->get_llvm_context(kernel->arch);
auto &config = prog->config;
std::string kernel_key = get_hashed_offline_cache_key(&config, kernel);
kernel->set_kernel_key_for_cache(kernel_key);
if (config.offline_cache && !config.async_mode &&
this->supports_offline_cache() && !kernel->is_evaluator) {
std::vector<LLVMCompiledData> res;
const bool ok = maybe_read_compilation_from_cache(kernel_key, res);
if (ok) {
CPUModuleToFunctionConverter converter(
tlctx, get_llvm_program(prog)->get_runtime_executor());
return converter.convert(kernel, std::move(res));
}
}
if (!kernel->lowered()) {
kernel->lower(/*to_executable=*/false);
}

CodeGenLLVMCPU gen(kernel, ir);
auto compiled_res = gen.run_compilation();

CPUModuleToFunctionConverter converter{gen.tlctx,
llvm_prog->get_runtime_executor()};
std::vector<LLVMCompiledData> data_list;
data_list.push_back(std::move(compiled_res));
if (!kernel->is_evaluator) {
cache_module(kernel_key, data_list);
}

return converter.convert(this->kernel, std::move(data_list));
}
TLANG_NAMESPACE_END
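Note: lowering is now split across the two entry points: codegen() only runs kernel->lower(/*to_executable=*/false), and run_compilation() finishes with kernel->offload_to_executable(ir), so a cache hit skips both phases while a miss pays for each exactly once. A toy sketch of that two-phase contract (stand-in Kernel type, assumed semantics):

#include <cassert>

// Toy kernel: lower(false) produces the cacheable form; offload_to_executable()
// is the final lowering that run_compilation() performs before emitting LLVM IR.
struct Kernel {
  bool lowered = false;
  bool executable = false;
  void lower(bool to_executable) {
    lowered = true;
    executable = to_executable;
  }
  void offload_to_executable() {
    assert(lowered && !executable);  // lowered first, made executable only once
    executable = true;
  }
};

int main() {
  Kernel k;
  if (!k.lowered) k.lower(/*to_executable=*/false);  // as in codegen()
  k.offload_to_executable();                         // as in run_compilation()
}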
6 changes: 6 additions & 0 deletions taichi/codegen/cpu/codegen_cpu.h
@@ -20,6 +20,12 @@ class CodeGenCPU : public KernelCodeGen {
IRNode *ir);
#endif // TI_WITH_LLVM

bool supports_offline_cache() const override {
return true;
}
LLVMCompiledData modulegen(std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr) override;

FunctionType codegen() override;
};

50 changes: 33 additions & 17 deletions taichi/codegen/cuda/codegen_cuda.cpp
@@ -15,7 +15,7 @@
#include "taichi/rhi/cuda/cuda_context.h"
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#include "taichi/util/action_recorder.h"

#include "taichi/analysis/offline_cache_util.h"
TLANG_NAMESPACE_BEGIN

using namespace llvm;
@@ -31,21 +31,6 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
: CodeGenLLVM(kernel, ir) {
}

bool supports_offline_cache() const override {
return true;
}

FunctionType gen() override {
auto compiled_res = run_compilation();

auto *llvm_prog = get_llvm_program(kernel->program);
CUDAModuleToFunctionConverter converter{tlctx,
llvm_prog->get_runtime_executor()};
std::vector<LLVMCompiledData> data;
data.push_back(std::move(compiled_res));
return converter.convert(this->kernel, std::move(data));
}

llvm::Value *create_print(std::string tag,
DataType dt,
llvm::Value *value) override {
@@ -737,7 +722,38 @@ static void set_arg_external_array(RuntimeContext *ctx,

FunctionType CodeGenCUDA::codegen() {
TI_AUTO_PROF
return CodeGenLLVMCUDA(kernel, ir).gen();
// TODO: move the offline cache part to the base class
auto *llvm_prog = get_llvm_program(prog);
auto *tlctx = llvm_prog->get_llvm_context(kernel->arch);
auto &config = prog->config;
std::string kernel_key = get_hashed_offline_cache_key(&config, kernel);
kernel->set_kernel_key_for_cache(kernel_key);
if (config.offline_cache && !config.async_mode &&
this->supports_offline_cache() && !kernel->is_evaluator) {
std::vector<LLVMCompiledData> res;
const bool ok = maybe_read_compilation_from_cache(kernel_key, res);
if (ok) {
CUDAModuleToFunctionConverter converter(
tlctx, get_llvm_program(prog)->get_runtime_executor());
return converter.convert(kernel, std::move(res));
}
}
if (!kernel->lowered()) {
kernel->lower(/*to_executable=*/false);
}

CodeGenLLVMCUDA gen(kernel, ir);
auto compiled_res = gen.run_compilation();

CUDAModuleToFunctionConverter converter{gen.tlctx,
llvm_prog->get_runtime_executor()};
std::vector<LLVMCompiledData> data_list;
data_list.push_back(std::move(compiled_res));
if (!kernel->is_evaluator) {
cache_module(kernel_key, data_list);
}

return converter.convert(this->kernel, std::move(data_list));
}

FunctionType CUDAModuleToFunctionConverter::convert(
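Note: the TODO at the top of CodeGenCUDA::codegen() concedes that the CPU and CUDA bodies are now near-duplicates. One later cleanup (a suggestion only, not part of this commit) would be the template-method pattern: the base class owns the cache-check/compile/cache sequence and defers just the backend-specific pieces. A self-contained sketch with toy types:

#include <iostream>
#include <string>
#include <vector>

// Template-method sketch: codegen() holds the flow duplicated today in
// CodeGenCPU::codegen() and CodeGenCUDA::codegen(); subclasses fill in
// only compilation and module-to-function conversion.
class KernelCodeGenBase {
 public:
  virtual ~KernelCodeGenBase() = default;
  std::string codegen(const std::string &key) {
    if (auto cached = read_cache(key); !cached.empty())
      return convert(cached);               // cache hit: no compilation
    std::vector<std::string> data{compile()};
    write_cache(key, data);
    return convert(data);
  }
 protected:
  virtual std::string compile() = 0;        // backend-specific
  virtual std::string convert(const std::vector<std::string> &data) = 0;
 private:
  std::vector<std::string> read_cache(const std::string &) { return {}; }
  void write_cache(const std::string &, const std::vector<std::string> &) {}
};

class CpuCodeGen : public KernelCodeGenBase {
  std::string compile() override { return "cpu module"; }
  std::string convert(const std::vector<std::string> &d) override {
    return "cpu launcher for " + d[0];
  }
};

int main() {
  CpuCodeGen cg;
  std::cout << cg.codegen("kernel_abc") << "\n";
}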
4 changes: 4 additions & 0 deletions taichi/codegen/cuda/codegen_cuda.h
@@ -19,6 +19,10 @@ class CodeGenCUDA : public KernelCodeGen {
IRNode *ir);
#endif // TI_WITH_LLVM

bool supports_offline_cache() const override {
return true;
}

FunctionType codegen() override;
};

60 changes: 5 additions & 55 deletions taichi/codegen/llvm/codegen_llvm.cpp
@@ -2548,58 +2548,15 @@ void CodeGenLLVM::emit_to_module() {
}

LLVMCompiledData CodeGenLLVM::run_compilation() {
const auto &config = prog->config;
std::string kernel_key =
get_hashed_offline_cache_key(&kernel->program->config, kernel);
kernel->set_kernel_key_for_cache(kernel_key);
if (config.offline_cache && !config.async_mode &&
this->supports_offline_cache() && !kernel->is_evaluator) {
LLVMCompiledData res;
const bool ok = maybe_read_compilation_from_cache(kernel_key, &res);
if (ok) {
return res;
}
}
// Final lowering

auto config = kernel->program->config;
kernel->offload_to_executable(ir);

if (!kernel->lowered()) {
kernel->lower();
}
emit_to_module();
eliminate_unused_functions();

// Updates LlvmProgramImpl->cache_data_ to save the compiled kernel
// information for successive uses in AOT or CGraph.
if (!kernel->is_evaluator) {
cache_module(kernel_key);
}

LLVMCompiledData res;
res.tasks = std::move(this->offloaded_tasks);
res.module = std::move(this->module);
return res;
}

bool CodeGenLLVM::maybe_read_compilation_from_cache(
const std::string &kernel_key,
LLVMCompiledData *data) {
const auto &config = prog->config;
auto reader =
LlvmOfflineCacheFileReader::make(config.offline_cache_file_path);
if (!reader) {
return false;
}

LlvmOfflineCache::KernelCacheData cache_data;
auto *tlctx = get_llvm_program(prog)->get_llvm_context(config.arch);
auto &llvm_ctx = *tlctx->get_this_thread_context();

if (!reader->get_kernel_cache(cache_data, kernel_key, llvm_ctx)) {
return false;
}
data->tasks = std::move(cache_data.compiled_data_list[0].tasks);
data->module = std::move(cache_data.compiled_data_list[0].module);
kernel->set_from_offline_cache();
return true;
return {std::move(this->offloaded_tasks), std::move(this->module)};
}

llvm::Value *CodeGenLLVM::create_xlogue(std::unique_ptr<Block> &block) {
@@ -2673,13 +2630,6 @@ void CodeGenLLVM::visit(FuncCallStmt *stmt) {
}
}

void CodeGenLLVM::cache_module(const std::string &kernel_key) {
std::vector<LLVMCompiledData> data;
data.emplace_back(offloaded_tasks, llvm::CloneModule(*module));
get_llvm_program(prog)->cache_kernel(kernel_key, data,
infer_launch_args(kernel));
}

LLVMCompiledData LLVMCompiledData::clone() const {
return {tasks, llvm::CloneModule(*module)};
}
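Note: run_compilation() is now compilation only: it finishes lowering, emits the module, prunes unused functions, and returns {tasks, module} by value, with all cache traffic hoisted to the callers. The retained LLVMCompiledData::clone() deep-copies via llvm::CloneModule because a module handed to the JIT is moved away (as the WASM path below shows), so the cache must keep its own copy. A toy of that clone-on-use ownership rule, with unique_ptr standing in for the module:

#include <memory>
#include <string>
#include <vector>

struct Module { std::string ir; };  // stand-in for the move-only llvm::Module

struct CompiledData {
  std::vector<std::string> tasks;
  std::unique_ptr<Module> module;
  // Deep copy, as LLVMCompiledData::clone() does with llvm::CloneModule:
  // the cached module survives even after a clone is fed to the JIT.
  CompiledData clone() const {
    return {tasks, std::make_unique<Module>(*module)};
  }
};

int main() {
  CompiledData cached{{"task_0"}, std::make_unique<Module>(Module{"ir"})};
  CompiledData for_jit = cached.clone();  // the JIT may consume this copy...
  (void)for_jit;
  return cached.module ? 0 : 1;           // ...while the cache keeps its own
}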
16 changes: 1 addition & 15 deletions taichi/codegen/llvm/codegen_llvm.h
@@ -139,15 +139,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
*
* @return LLVMCompiledData
*/
LLVMCompiledData run_compilation();

// TODO: This function relies largely on `run_compilation()`. Name it better.
virtual FunctionType gen(){TI_NOT_IMPLEMENTED};

virtual bool supports_offline_cache() const {
return false;
}

virtual LLVMCompiledData run_compilation();
// For debugging only
virtual llvm::Value *create_print(std::string tag,
DataType dt,
@@ -432,12 +424,6 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type);

~CodeGenLLVM() override = default;

private:
bool maybe_read_compilation_from_cache(const std::string &kernel_key,
LLVMCompiledData *data);

void cache_module(const std::string &kernel_key);
};

} // namespace lang
24 changes: 14 additions & 10 deletions taichi/codegen/wasm/codegen_wasm.cpp
@@ -215,8 +215,7 @@ class CodeGenLLVMWASM : public CodeGenLLVM {
TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs()));
}

FunctionType gen() override {
TI_AUTO_PROF
LLVMCompiledData run_compilation() override {
// lower kernel
if (!kernel->lowered()) {
kernel->lower();
@@ -236,19 +235,24 @@
}
return func_name == offloaded_task_name;
});
tlctx->add_module(std::move(module));
auto kernel_symbol = tlctx->lookup_function_pointer(offloaded_task_name);
return [=](RuntimeContext &context) {
TI_TRACE("Launching Taichi Kernel Function");
auto func = (int32(*)(void *))kernel_symbol;
func(&context);
};
LLVMCompiledData res;
res.tasks.emplace_back(offloaded_task_name);
res.module = std::move(this->module);
return res;
}
};

FunctionType CodeGenWASM::codegen() {
TI_AUTO_PROF
return CodeGenLLVMWASM(kernel, ir).gen();
CodeGenLLVMWASM gen(kernel, ir);
auto res = gen.run_compilation();
gen.tlctx->add_module(std::move(res.module));
auto kernel_symbol = gen.tlctx->lookup_function_pointer(res.tasks[0].name);
return [=](RuntimeContext &context) {
TI_TRACE("Launching Taichi Kernel Function");
auto func = (int32(*)(void *))kernel_symbol;
func(&context);
};
}

LLVMCompiledData CodeGenWASM::modulegen(std::unique_ptr<llvm::Module> &&module,
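Note: the WASM backend now splits the old gen() in two: run_compilation() returns the module plus the offloaded task's name, and codegen() JITs the module, resolves the entry symbol, and wraps it in a capturing lambda. A self-contained analogue of that wrap-the-symbol step, with an ordinary function standing in for the pointer that tlctx->lookup_function_pointer() returns:

#include <functional>
#include <iostream>

struct RuntimeContext { int arg = 0; };

// Stand-in for the JIT-resolved kernel entry point.
int kernel_body(void *ctx) {
  std::cout << "arg = " << static_cast<RuntimeContext *>(ctx)->arg << "\n";
  return 0;
}

using FunctionType = std::function<void(RuntimeContext &)>;

// Same shape as the lambda in CodeGenWASM::codegen(): capture the raw symbol
// by value and cast it back to the kernel's calling convention at launch.
FunctionType wrap(void *kernel_symbol) {
  return [kernel_symbol](RuntimeContext &context) {
    auto func = (int (*)(void *))kernel_symbol;
    func(&context);
  };
}

int main() {
  RuntimeContext ctx{42};
  wrap(reinterpret_cast<void *>(&kernel_body))(ctx);  // prints "arg = 42"
}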