Skip to content

Commit

Permalink
[vulkan] Support offline cache on Vulkan (#5825)
Browse files Browse the repository at this point in the history
* Support offline cache on Vulkan

* Implement metadata merging & name mangling

* Redesign offline cache path

* Refit test_offline_cache.py

* Use TI_WIP_OFFLINE_CACHE

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Make TestProgram happy

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PGZXB and pre-commit-ci[bot] authored Aug 23, 2022
1 parent 1390bcc commit 1b4006f
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 107 deletions.
20 changes: 20 additions & 0 deletions taichi/analysis/offline_cache_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,26 @@ namespace offline_cache {
constexpr std::size_t offline_cache_key_length = 65;
constexpr std::size_t min_mangled_name_length = offline_cache_key_length + 2;

// Picks the per-backend offline cache directory underneath `base_path`.
//
// LLVM-based archs get the "llvm" subdirectory and Vulkan gets "gfx". For any
// other arch, `base_path` is returned unchanged.
std::string get_cache_path_by_arch(const std::string &base_path, Arch arch) {
  if (arch_uses_llvm(arch)) {
    return taichi::join_path(base_path, "llvm");
  }
  if (arch == Arch::vulkan) {
    return taichi::join_path(base_path, "gfx");
  }
  return base_path;
}

// Tells whether the work-in-progress offline cache should be used.
//
// CompileConfig::offline_cache is a global option that enables the offline
// cache on all backends. The WIP offline cache stays disabled by default and
// is only switched on for developing/testing: `enable_hint` must be true AND
// the environment variable TI_WIP_OFFLINE_CACHE must be set to a value whose
// first character is '1'.
bool enabled_wip_offline_cache(bool enable_hint) {
  if (!enable_hint) {
    return false;
  }
  const char *env_value = std::getenv("TI_WIP_OFFLINE_CACHE");
  // Only the first character is inspected; unset or empty means disabled.
  return env_value != nullptr && env_value[0] == '1';
}

std::string mangle_name(const std::string &primal_name,
const std::string &key) {
// Result: {primal_name}{key: char[65]}_{(checksum(primal_name)) ^
Expand Down
4 changes: 4 additions & 0 deletions taichi/analysis/offline_cache_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <string>

#include "taichi/rhi/arch.h"

namespace taichi {
namespace lang {

Expand All @@ -17,6 +19,8 @@ void gen_offline_cache_key(Program *prog, IRNode *ast, std::ostream *os);

namespace offline_cache {

std::string get_cache_path_by_arch(const std::string &base_path, Arch arch);
bool enabled_wip_offline_cache(bool enable_hint);
std::string mangle_name(const std::string &primal_name, const std::string &key);
bool try_demangle_name(const std::string &mangled_name,
std::string &primal_name,
Expand Down
55 changes: 50 additions & 5 deletions taichi/runtime/gfx/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,17 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,
write_to_binary_file(ti_aot_data_, bin_path);

auto converted = AotDataConverter::convert(ti_aot_data_);
for (int i = 0; i < ti_aot_data_.kernels.size(); ++i) {
const auto &spirv_codes = ti_aot_data_.spirv_codes;
for (int i = 0; i < std::min(ti_aot_data_.kernels.size(), spirv_codes.size());
++i) {
auto &k = ti_aot_data_.kernels[i];
for (int j = 0; j < k.tasks_attribs.size(); ++j) {
std::string spv_path = write_spv_file(output_dir, k.tasks_attribs[j],
ti_aot_data_.spirv_codes[i][j]);
converted.kernels[k.name].tasks[j].source_path = spv_path;
for (int j = 0; j < std::min(k.tasks_attribs.size(), spirv_codes[i].size());
++j) {
if (!spirv_codes[i][j].empty()) {
std::string spv_path =
write_spv_file(output_dir, k.tasks_attribs[j], spirv_codes[i][j]);
converted.kernels[k.name].tasks[j].source_path = spv_path;
}
}
}

Expand All @@ -146,6 +151,46 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,
dump_graph(output_dir);
}

// Only for offline cache.
//
// Renames every task of every kernel in ti_aot_data_ to
// "{kernel name}{task index}", where the index is the task's position inside
// its kernel.
void AotModuleBuilderImpl::mangle_aot_data() {
  for (auto &kernel : ti_aot_data_.kernels) {
    std::size_t task_index = 0;
    for (auto &task : kernel.tasks_attribs) {
      task.name = kernel.name + std::to_string(task_index);
      ++task_index;
    }
  }
}

void AotModuleBuilderImpl::merge_with_old_meta_data(const std::string &path) {
// Only for offline cache
auto filename = taichi::join_path(path, "metadata.tcb");
if (taichi::path_exists(filename)) {
TaichiAotData old_data;
read_from_binary_file(old_data, filename);
// Ignore root_buffer_size and fields which aren't needed for offline cache
ti_aot_data_.kernels.insert(ti_aot_data_.kernels.end(),
old_data.kernels.begin(),
old_data.kernels.end());
}
}

// Looks up a kernel by name in ti_aot_data_ and packs what is needed to
// register it with the runtime.
//
// `kernel_name` is compared against the `name` of each entry in
// ti_aot_data_.kernels. Returns the register params of the first match that
// also has an entry in ti_aot_data_.spirv_codes, or std::nullopt when there
// is no such kernel.
std::optional<GfxRuntime::RegisterParams>
AotModuleBuilderImpl::try_get_kernel_register_params(
    const std::string &kernel_name) const {
  const auto &kernels = ti_aot_data_.kernels;
  const auto &spirv_codes = ti_aot_data_.spirv_codes;
  // kernels can be longer than spirv_codes: merge_with_old_meta_data()
  // appends kernels only. Search only the indices valid in both containers
  // (as dump() does) so spirv_codes is never read out of range.
  const std::size_t searchable = std::min(kernels.size(), spirv_codes.size());
  for (std::size_t i = 0; i < searchable; ++i) {
    if (kernels[i].name != kernel_name) {
      continue;
    }
    GfxRuntime::RegisterParams result;
    result.kernel_attribs = kernels[i];
    result.task_spirv_source_codes = spirv_codes[i];
    // We only support a single SNodeTree during AOT.
    result.num_snode_trees = 1;
    return result;
  }
  return std::nullopt;
}

void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Kernel *kernel) {
spirv::lower(kernel);
Expand Down
5 changes: 5 additions & 0 deletions taichi/runtime/gfx/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
void dump(const std::string &output_dir,
const std::string &filename) const override;

void mangle_aot_data();
void merge_with_old_meta_data(const std::string &path);
std::optional<GfxRuntime::RegisterParams> try_get_kernel_register_params(
const std::string &kernel_name) const;

private:
void add_per_backend(const std::string &identifier, Kernel *kernel) override;

Expand Down
13 changes: 9 additions & 4 deletions taichi/runtime/program_impls/llvm/llvm_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "taichi/codegen/llvm/struct_llvm.h"
#include "taichi/runtime/llvm/aot_graph_data.h"
#include "taichi/runtime/llvm/llvm_offline_cache.h"
#include "taichi/analysis/offline_cache_util.h"
#include "taichi/runtime/cpu/aot_module_builder_impl.h"

#if defined(TI_WITH_CUDA)
Expand All @@ -30,7 +31,8 @@ LlvmProgramImpl::LlvmProgramImpl(CompileConfig &config_,
cache_data_ = std::make_unique<LlvmOfflineCache>();
if (config_.offline_cache) {
cache_reader_ =
LlvmOfflineCacheFileReader::make(config_.offline_cache_file_path);
LlvmOfflineCacheFileReader::make(offline_cache::get_cache_path_by_arch(
config_.offline_cache_file_path, config->arch));
}
}

Expand Down Expand Up @@ -182,16 +184,19 @@ void LlvmProgramImpl::dump_cache_data_to_disk() {
auto policy = LlvmOfflineCacheFileWriter::string_to_clean_cache_policy(
config->offline_cache_cleaning_policy);
LlvmOfflineCacheFileWriter::clean_cache(
config->offline_cache_file_path, policy,
config->offline_cache_max_size_of_files,
offline_cache::get_cache_path_by_arch(config->offline_cache_file_path,
config->arch),
policy, config->offline_cache_max_size_of_files,
config->offline_cache_cleaning_factor);
if (!cache_data_->kernels.empty()) {
LlvmOfflineCacheFileWriter writer{};
writer.set_data(std::move(cache_data_));

// Note: For offline-cache, new-metadata should be merged with
// old-metadata
writer.dump(config->offline_cache_file_path, LlvmOfflineCache::LL, true);
writer.dump(offline_cache::get_cache_path_by_arch(
config->offline_cache_file_path, config->arch),
LlvmOfflineCache::LL, true);
}
}
}
Expand Down
88 changes: 83 additions & 5 deletions taichi/runtime/program_impls/vulkan/vulkan_program.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "taichi/runtime/program_impls/vulkan/vulkan_program.h"

#include "taichi/analysis/offline_cache_util.h"
#include "taichi/aot/graph_data.h"
#include "taichi/runtime/gfx/aot_module_builder_impl.h"
#include "taichi/runtime/gfx/snode_tree_manager.h"
#include "taichi/runtime/gfx/aot_module_loader_impl.h"
Expand Down Expand Up @@ -66,19 +68,57 @@ VulkanProgramImpl::VulkanProgramImpl(CompileConfig &config)
: ProgramImpl(config) {
}

// Registers `params` with `runtime` and returns a callable that launches the
// registered kernel on that runtime with the RuntimeContext it is given.
//
// The returned callable stores the raw `runtime` pointer, which must still be
// valid whenever the callable is invoked.
FunctionType register_params_to_executable(
    gfx::GfxRuntime::RegisterParams &&params,
    gfx::GfxRuntime *runtime) {
  // Registration happens here, once, while the closure is being built.
  return [runtime,
          kernel_handle = runtime->register_taichi_kernel(std::move(params))](
             RuntimeContext &ctx) {
    runtime->launch_kernel(kernel_handle, &ctx);
  };
}

FunctionType compile_to_executable(Kernel *kernel,
gfx::GfxRuntime *runtime,
gfx::SNodeTreeManager *snode_tree_mgr) {
auto handle = runtime->register_taichi_kernel(
return register_params_to_executable(
gfx::run_codegen(kernel, runtime->get_ti_device(),
snode_tree_mgr->get_compiled_structs()));
return [runtime, handle](RuntimeContext &ctx) {
runtime->launch_kernel(handle, &ctx);
};
snode_tree_mgr->get_compiled_structs()),
runtime);
}

FunctionType VulkanProgramImpl::compile(Kernel *kernel,
OffloadedStmt *offloaded) {
// The Vulkan offline cache depends on AOT, which only supports a single
// SNodeTree. Hacking aot::Module can resolve this problem, but we prefer to
// fix it after supporting multiple SNodeTrees in AOT.
if (offline_cache::enabled_wip_offline_cache(config->offline_cache) &&
!kernel->is_evaluator &&
snode_tree_mgr_->get_compiled_structs().size() == 1) {
auto kernel_key = get_hashed_offline_cache_key(config, kernel);
kernel->set_kernel_key_for_cache(kernel_key);
const auto &cached_module = get_cached_module();
aot::Kernel *cached_kernel = nullptr;
if (cached_module &&
(cached_kernel = cached_module->get_kernel(kernel_key))) {
TI_DEBUG("Create kernel '{}' from cache (key='{}')", kernel->get_name(),
kernel_key);
kernel->set_from_offline_cache();
return
[cached_kernel](RuntimeContext &ctx) { cached_kernel->launch(&ctx); };
} else { // Compile & Cache it
TI_DEBUG("Cache kernel '{}' (key='{}')", kernel->get_name(), kernel_key);
auto *cache_builder = static_cast<gfx::AotModuleBuilderImpl *>(
get_caching_module_builder().get());
TI_ASSERT(cache_builder != nullptr);
cache_builder->add(kernel_key, kernel);
auto params_opt =
cache_builder->try_get_kernel_register_params(kernel_key);
TI_ASSERT(params_opt.has_value());
return register_params_to_executable(std::move(params_opt.value()),
vulkan_runtime_.get());
}
}

spirv::lower(kernel);
return compile_to_executable(kernel, vulkan_runtime_.get(),
snode_tree_mgr_.get());
Expand Down Expand Up @@ -209,7 +249,45 @@ std::unique_ptr<aot::Kernel> VulkanProgramImpl::make_aot_kernel(
std::move(kparams));
}

void VulkanProgramImpl::dump_cache_data_to_disk() {
if (offline_cache::enabled_wip_offline_cache(config->offline_cache)) {
auto path = offline_cache::get_cache_path_by_arch(
config->offline_cache_file_path, config->arch);
taichi::create_directories(path);
auto *cache_builder = static_cast<gfx::AotModuleBuilderImpl *>(
get_caching_module_builder().get());
cache_builder->mangle_aot_data();
cache_builder->merge_with_old_meta_data(path);
cache_builder->dump(path, "");
}
}

// Returns the AOT module builder used to collect kernels for the offline
// cache, creating it through make_aot_module_builder() on first use.
const std::unique_ptr<AotModuleBuilder> &
VulkanProgramImpl::get_caching_module_builder() {
  if (caching_module_builder_ == nullptr) {
    caching_module_builder_ = make_aot_module_builder();
  }
  return caching_module_builder_;
}

const std::unique_ptr<aot::Module> &VulkanProgramImpl::get_cached_module() {
if (!cached_module_) {
auto path = offline_cache::get_cache_path_by_arch(
config->offline_cache_file_path, config->arch);
if (taichi::path_exists(taichi::join_path(path, "metadata.tcb")) &&
taichi::path_exists(taichi::join_path(path, "graphs.tcb"))) {
gfx::AotModuleParams params;
params.module_path = path;
params.runtime = vulkan_runtime_.get();
cached_module_ = gfx::make_aot_module(params, config->arch);
}
}
return cached_module_;
}

// Releases the owned objects in an explicit order instead of relying on the
// implicit member destruction order.
VulkanProgramImpl::~VulkanProgramImpl() {
// The offline-cache objects go first: cached_module_ is created with a raw
// pointer to vulkan_runtime_ (see get_cached_module()).
caching_module_builder_.reset();
cached_module_.reset();
// NOTE(review): presumably the runtime depends on the embedded device, hence
// runtime before device — confirm.
vulkan_runtime_.reset();
embedded_device_.reset();
}
Expand Down
8 changes: 8 additions & 0 deletions taichi/runtime/program_impls/vulkan/vulkan_program.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once
#include "taichi/aot/module_loader.h"
#include "taichi/codegen/spirv/spirv_codegen.h"
#include "taichi/codegen/spirv/snode_struct_compiler.h"
#include "taichi/codegen/spirv/kernel_utils.h"
Expand Down Expand Up @@ -91,13 +92,20 @@ class VulkanProgramImpl : public ProgramImpl {

std::unique_ptr<aot::Kernel> make_aot_kernel(Kernel &kernel) override;

void dump_cache_data_to_disk() override;

const std::unique_ptr<AotModuleBuilder> &get_caching_module_builder();
const std::unique_ptr<aot::Module> &get_cached_module();

~VulkanProgramImpl();

private:
std::unique_ptr<vulkan::VulkanDeviceCreator> embedded_device_{nullptr};
std::unique_ptr<gfx::GfxRuntime> vulkan_runtime_{nullptr};
std::unique_ptr<gfx::SNodeTreeManager> snode_tree_mgr_{nullptr};
std::vector<spirv::CompiledSNodeStructs> aot_compiled_snode_structs_;
std::unique_ptr<AotModuleBuilder> caching_module_builder_{nullptr};
std::unique_ptr<aot::Module> cached_module_{nullptr};
};
} // namespace lang
} // namespace taichi
Loading

0 comments on commit 1b4006f

Please sign in to comment.