Skip to content

Commit

Permalink
[llvm] [refactor] (Decomp of #5251 8/n) Refactor KernelCacheData (#5383)
Browse files Browse the repository at this point in the history
* [llvm] (Decomp of #5251 8/n) Refactor KernelCacheData

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lin-hitonami and pre-commit-ci[bot] authored Jul 11, 2022
1 parent 3459901 commit 0dfd166
Show file tree
Hide file tree
Showing 10 changed files with 96 additions and 90 deletions.
22 changes: 10 additions & 12 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2595,15 +2595,9 @@ bool CodeGenLLVM::maybe_read_compilation_from_cache(
if (!reader->get_kernel_cache(cache_data, kernel_key, llvm_ctx)) {
return false;
}
this->module = std::move(cache_data.owned_module);
for (auto &task : cache_data.offloaded_task_list) {
auto &t = this->offloaded_tasks.emplace_back(task.name);
t.block_dim = task.block_dim;
t.grid_dim = task.grid_dim;
}
data->tasks = std::move(cache_data.compiled_data_list[0].tasks);
data->module = std::move(cache_data.compiled_data_list[0].module);
kernel->set_from_offline_cache();
data->tasks = std::move(this->offloaded_tasks);
data->module = std::move(this->module);
return true;
}

Expand Down Expand Up @@ -2679,10 +2673,14 @@ void CodeGenLLVM::visit(FuncCallStmt *stmt) {
}

void CodeGenLLVM::cache_module(const std::string &kernel_key) {
std::vector<OffloadedTask> offloaded_task_list = offloaded_tasks;
get_llvm_program(prog)->cache_kernel(kernel_key, this->module.get(),
infer_launch_args(kernel),
std::move(offloaded_task_list));
std::vector<LLVMCompiledData> data;
data.emplace_back(offloaded_tasks, llvm::CloneModule(*module));
get_llvm_program(prog)->cache_kernel(kernel_key, data,
infer_launch_args(kernel));
}

// Builds a deep copy of this compiled data: the task list is copied and the
// LLVM module is duplicated through llvm::CloneModule, so the returned object
// owns a module that is separate from this one.
// `module` is dereferenced unconditionally, so it must be set when this is
// called.
LLVMCompiledData LLVMCompiledData::clone() const {
  auto module_copy = llvm::CloneModule(*module);
  return {tasks, std::move(module_copy)};
}

TLANG_NAMESPACE_END
Expand Down
1 change: 1 addition & 0 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ struct LLVMCompiledData {
std::unique_ptr<llvm::Module> module)
: tasks(std::move(tasks)), module(std::move(module)) {
}
LLVMCompiledData clone() const;
TI_IO_DEF(tasks);
};

Expand Down
6 changes: 2 additions & 4 deletions taichi/runtime/cpu/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@ class AotModuleImpl : public LlvmAotModule {
auto *tlctx = executor_->get_llvm_context(arch);

CPUModuleToFunctionConverter converter{tlctx, executor_};
std::vector<LLVMCompiledData> data;
data.emplace_back(std::move(loaded.offloaded_task_list),
std::move(loaded.owned_module));
return converter.convert(name, loaded.args, std::move(data));
return converter.convert(name, loaded.args,
std::move(loaded.compiled_data_list));
}

std::unique_ptr<aot::KernelTemplate> make_new_kernel_template(
Expand Down
7 changes: 2 additions & 5 deletions taichi/runtime/cuda/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,8 @@ class AotModuleImpl : public LlvmAotModule {
auto *tlctx = executor_->get_llvm_context(arch);

CUDAModuleToFunctionConverter converter{tlctx, executor_};

std::vector<LLVMCompiledData> data;
data.emplace_back(std::move(loaded.offloaded_task_list),
std::move(loaded.owned_module));
return converter.convert(name, loaded.args, std::move(data));
return converter.convert(name, loaded.args,
std::move(loaded.compiled_data_list));
}

std::unique_ptr<aot::KernelTemplate> make_new_kernel_template(
Expand Down
4 changes: 1 addition & 3 deletions taichi/runtime/llvm/llvm_aot_module_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ void LlvmAotModuleBuilder::add_per_backend(const std::string &identifier,
auto compiled = compile_kernel(kernel);
LlvmOfflineCache::KernelCacheData kcache;
kcache.kernel_key = identifier;
kcache.module = compiled.module.get();
kcache.owned_module = std::move(compiled.module);
kcache.compiled_data_list.push_back(std::move(compiled));
kcache.args = infer_launch_args(kernel);
kcache.offloaded_task_list = std::move(compiled.tasks);
kcache.last_used_at = std::time(nullptr);
kcache.created_at = std::time(nullptr);
cache_.kernels[identifier] = std::move(kcache);
Expand Down
86 changes: 51 additions & 35 deletions taichi/runtime/llvm/llvm_offline_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,16 @@ bool LlvmOfflineCacheFileReader::get_kernel_cache(
}

auto &kernel_data = itr->second;
if (kernel_data.owned_module == nullptr) {
const std::string filename_prefix = taichi::join_path(path_, key);
kernel_data.owned_module = load_module(filename_prefix, key, llvm_ctx);
TI_ASSERT(kernel_data.owned_module != nullptr);
kernel_data.module = kernel_data.owned_module.get();
for (int i = 0; i < kernel_data.compiled_data_list.size(); i++) {
auto &data = kernel_data.compiled_data_list[i];
if (!data.module) {
std::string filename_prefix =
taichi::join_path(path_, key + "." + std::to_string(i));
data.module = load_module(filename_prefix, key, llvm_ctx);
TI_ASSERT(data.module);
}
res.compiled_data_list.emplace_back(data.tasks,
llvm::CloneModule(*data.module));
}

kernel_data.last_used_at = std::time(nullptr);
Expand All @@ -121,9 +126,6 @@ bool LlvmOfflineCacheFileReader::get_kernel_cache(
res.last_used_at = kernel_data.last_used_at;
res.kernel_key = key;
res.args = kernel_data.args;
res.offloaded_task_list = kernel_data.offloaded_task_list;
res.owned_module = llvm::CloneModule(*kernel_data.module);
res.module = res.owned_module.get();
return true;
}

Expand Down Expand Up @@ -168,22 +170,24 @@ void LlvmOfflineCacheFileWriter::dump(const std::string &path,
return llvm_os.tell();
};
{
auto *mod = v.module;
if (!mod) {
mod = v.owned_module.get();
}
TI_ASSERT(mod != nullptr);

mangle_offloaded_task_name(k, mod, v.offloaded_task_list);
if (format & Format::LL) {
size += write_llvm_module(".ll", [mod](llvm::raw_os_ostream &os) {
mod->print(os, /*AAW=*/nullptr);
});
}
if (format & Format::BC) {
size += write_llvm_module(".bc", [mod](llvm::raw_os_ostream &os) {
llvm::WriteBitcodeToFile(*mod, os);
});
mangle_offloaded_task_name(k, v.compiled_data_list);
for (int i = 0; i < v.compiled_data_list.size(); i++) {
auto &data = v.compiled_data_list[i];
auto *mod = data.module.get();
TI_ASSERT(mod != nullptr);
std::string suffix = "." + std::to_string(i);
if (format & Format::LL) {
size += write_llvm_module(suffix + ".ll",
[mod](llvm::raw_os_ostream &os) {
mod->print(os, /*AAW=*/nullptr);
});
}
if (format & Format::BC) {
size += write_llvm_module(suffix + ".bc",
[mod](llvm::raw_os_ostream &os) {
llvm::WriteBitcodeToFile(*mod, os);
});
}
}
}

Expand Down Expand Up @@ -240,16 +244,18 @@ void LlvmOfflineCacheFileWriter::merge_with(LlvmOfflineCache &&data) {

void LlvmOfflineCacheFileWriter::mangle_offloaded_task_name(
const std::string &kernel_key,
llvm::Module *module,
std::vector<OffloadedTask> &offloaded_task_list) {
std::vector<LLVMCompiledData> &compiled_data_list) {
if (!mangled_) {
std::size_t cnt = 0;
for (auto &e : offloaded_task_list) {
std::string mangled_name = kernel_key + std::to_string(cnt++);
auto func = module->getFunction(e.name);
TI_ASSERT(func != nullptr);
func->setName(mangled_name);
e.name = mangled_name;
for (auto &e : compiled_data_list) {
for (auto &offload : e.tasks) {
std::string mangled_name = kernel_key + std::to_string(cnt++);

auto func = e.module->getFunction(offload.name);
TI_ASSERT(func != nullptr);
func->setName(mangled_name);
offload.name = mangled_name;
}
}
}
}
Expand Down Expand Up @@ -314,9 +320,11 @@ void LlvmOfflineCacheFileWriter::clean_cache(const std::string &path,
}
TI_ASSERT(q.size() <= cnt);
while (!q.empty()) {
for (const auto &f :
get_possible_llvm_cache_filename_by_key(q.top().kernel_key)) {
taichi::remove(taichi::join_path(path, f));
for (int i = 0; i < q.top().compiled_data_list.size(); i++) {
for (const auto &f : get_possible_llvm_cache_filename_by_key(
q.top().kernel_key + "." + std::to_string(i))) {
taichi::remove(taichi::join_path(path, f));
}
}
q.pop();
}
Expand All @@ -343,5 +351,13 @@ LlvmOfflineCacheFileWriter::string_to_clean_cache_policy(
return Never;
}

// Deep-copies this kernel cache entry. Every element of `compiled_data_list`
// is duplicated through LLVMCompiledData::clone(), so the result does not
// share module ownership with this object.
//
// Only `kernel_key`, `args` and the cloned data list are handed to the
// returned object.
// NOTE(review): `size`, `created_at` and `last_used_at` are not passed on and
// presumably keep their default values in the copy; confirm against the
// KernelCacheData constructors.
LlvmOfflineCache::KernelCacheData LlvmOfflineCache::KernelCacheData::clone()
    const {
  std::vector<LLVMCompiledData> new_data_list;
  // The final element count is known, so allocate once instead of letting
  // push_back regrow the vector.
  new_data_list.reserve(compiled_data_list.size());
  for (const auto &data : compiled_data_list) {
    new_data_list.push_back(data.clone());
  }
  return {kernel_key, args, std::move(new_data_list)};
}
} // namespace lang
} // namespace taichi
12 changes: 5 additions & 7 deletions taichi/runtime/llvm/llvm_offline_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ struct LlvmOfflineCache {
struct KernelCacheData {
std::string kernel_key;
std::vector<LlvmLaunchArgInfo> args;
std::vector<OffloadedTask> offloaded_task_list;

std::unique_ptr<llvm::Module> owned_module{nullptr};
llvm::Module *module{nullptr};
std::vector<LLVMCompiledData> compiled_data_list;

// For cache cleaning
std::size_t size{0}; // byte
Expand All @@ -38,9 +35,11 @@ struct LlvmOfflineCache {
KernelCacheData &operator=(KernelCacheData &&) = default;
~KernelCacheData() = default;

KernelCacheData clone() const;

TI_IO_DEF(kernel_key,
args,
offloaded_task_list,
compiled_data_list,
size,
created_at,
last_used_at);
Expand Down Expand Up @@ -185,8 +184,7 @@ class LlvmOfflineCacheFileWriter {

void mangle_offloaded_task_name(
const std::string &kernel_key,
llvm::Module *module,
std::vector<OffloadedTask> &offloaded_task_list);
std::vector<LLVMCompiledData> &compiled_data_list);

LlvmOfflineCache data_;
bool mangled_{false};
Expand Down
22 changes: 9 additions & 13 deletions taichi/runtime/program_impls/llvm/llvm_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,32 +107,28 @@ std::unique_ptr<aot::Kernel> LlvmProgramImpl::make_aot_kernel(Kernel &kernel) {
TI_ASSERT(cache_data_->kernels.count(kernel_key));
const LlvmOfflineCache::KernelCacheData &kernel_data =
cache_data_->kernels[kernel_key];

LlvmOfflineCache::KernelCacheData compiled_kernel;
LlvmOfflineCache::KernelCacheData compiled_kernel = kernel_data.clone();
compiled_kernel.kernel_key = kernel.get_name();
compiled_kernel.owned_module =
llvm::CloneModule(*kernel_data.owned_module.get());
compiled_kernel.args = kernel_data.args;
compiled_kernel.offloaded_task_list = kernel_data.offloaded_task_list;
return std::make_unique<llvm_aot::KernelImpl>(compiled_fn, kernel.get_name(),
std::move(compiled_kernel));
}

void LlvmProgramImpl::cache_kernel(
const std::string &kernel_key,
llvm::Module *module,
std::vector<LlvmLaunchArgInfo> &&args,
std::vector<OffloadedTask> &&offloaded_task_list) {
const std::vector<LLVMCompiledData> &data_list,
std::vector<LlvmLaunchArgInfo> &&args) {
if (cache_data_->kernels.find(kernel_key) != cache_data_->kernels.end()) {
return;
}
auto &kernel_cache = cache_data_->kernels[kernel_key];
kernel_cache.created_at = std::time(nullptr);
kernel_cache.last_used_at = std::time(nullptr);
kernel_cache.kernel_key = kernel_key;
kernel_cache.owned_module = llvm::CloneModule(*module);
for (const auto &data : data_list) {
kernel_cache.compiled_data_list.emplace_back(
data.tasks, llvm::CloneModule(*data.module));
}
kernel_cache.args = std::move(args);
kernel_cache.offloaded_task_list = std::move(offloaded_task_list);
kernel_cache.created_at = std::time(nullptr);
kernel_cache.last_used_at = std::time(nullptr);
}

void LlvmProgramImpl::cache_field(int snode_tree_id,
Expand Down
6 changes: 3 additions & 3 deletions taichi/runtime/program_impls/llvm/llvm_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ class LlvmProgramImpl : public ProgramImpl {
void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override;

void cache_kernel(const std::string &kernel_key,
llvm::Module *module,
std::vector<LlvmLaunchArgInfo> &&args,
std::vector<OffloadedTask> &&offloaded_task_list);
const std::vector<LLVMCompiledData> &data_list,
std::vector<LlvmLaunchArgInfo> &&args);
;

void cache_field(int snode_tree_id,
int root_id,
Expand Down
20 changes: 12 additions & 8 deletions tests/cpp/llvm/llvm_offline_cache_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,13 @@ TEST_P(LlvmOfflineCacheTest, ReadWrite) {
kcache.created_at = 1;
kcache.last_used_at = 1;
kcache.kernel_key = kKernelName;
kcache.owned_module = make_module(*llvm_ctx);
kcache.module = kcache.owned_module.get();
kcache.offloaded_task_list.emplace_back(kTaskName, kBlockDim, kGridDim);
std::vector<OffloadedTask> tasks;
OffloadedTask task;
task.name = kTaskName;
task.block_dim = kBlockDim;
task.grid_dim = kGridDim;
tasks.push_back(task);
kcache.compiled_data_list.emplace_back(tasks, make_module(*llvm_ctx));
kcache.args = arg_infos;
writer.add_kernel_cache(kKernelName, std::move(kcache));
writer.set_no_mangle();
Expand All @@ -115,13 +119,13 @@ TEST_P(LlvmOfflineCacheTest, ReadWrite) {
const bool ok = reader->get_kernel_cache(kcache, kKernelName, *llvm_ctx);
ASSERT_TRUE(ok);
EXPECT_EQ(kcache.kernel_key, kKernelName);
EXPECT_EQ(kcache.offloaded_task_list.size(), 1);
const auto &task0 = kcache.offloaded_task_list.front();
EXPECT_EQ(kcache.compiled_data_list[0].tasks.size(), 1);
const auto &task0 = kcache.compiled_data_list[0].tasks.front();
EXPECT_EQ(task0.name, kTaskName);

ASSERT_NE(kcache.owned_module, nullptr);
kcache.module->dump();
tlctx_->add_module(std::move(kcache.owned_module));
ASSERT_NE(kcache.compiled_data_list[0].module, nullptr);
kcache.compiled_data_list[0].module->dump();
tlctx_->add_module(std::move(kcache.compiled_data_list[0].module));
using FuncType = int (*)(int, int);
FuncType my_add = (FuncType)tlctx_->lookup_function_pointer(kTaskName);
const auto res = my_add(40, 2);
Expand Down

0 comments on commit 0dfd166

Please sign in to comment.