Skip to content

Commit

Permalink
[aot] [llvm] LLVM AOT Field #0: Implemented FieldCacheData & refactor…
Browse files Browse the repository at this point in the history
…ed initialize_llvm_runtime_snodes() (#5108)

* [aot] [llvm] Implemented FieldCacheData and refactored initialize_llvm_runtime_snodes()

* Addressed compilation errors

* Added initialization for struct members

* Minor fix
  • Loading branch information
jim19930609 authored Jun 10, 2022
1 parent 928aef1 commit 88f75a9
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 27 deletions.
32 changes: 31 additions & 1 deletion taichi/llvm/llvm_offline_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,37 @@ struct LlvmOfflineCache {
TI_IO_DEF(kernel_key, args, offloaded_task_list);
};

std::unordered_map<std::string, KernelCacheData> kernels;
struct FieldCacheData {
struct SNodeCacheData {
int id{0};
SNodeType type = SNodeType::undefined;
size_t cell_size_bytes{0};
size_t chunk_size{0};

TI_IO_DEF(id, type, cell_size_bytes, chunk_size);
};

int tree_id{0};
int root_id{0};
size_t root_size{0};
std::vector<SNodeCacheData> snode_metas;

TI_IO_DEF(tree_id, root_id, root_size, snode_metas);

// TODO(zhanlue)
// Serialize/Deserialize the llvm::Module from StructCompiler
// At runtime, make sure loaded Field-Modules and Kernel-Modules are linked
// altogether.
};

// TODO(zhanlue): we need a better identifier for each FieldCacheData
// (SNodeTree). Given that snode_tree_id values are not contiguous, it is
// unreasonable to ask users to remember each of the snode_tree_ids.
// ** Find a way to name each SNodeTree **
std::unordered_map<int, FieldCacheData> fields; // key = snode_tree_id

std::unordered_map<std::string, KernelCacheData>
kernels; // key = kernel_name

TI_IO_DEF(kernels);
};
Expand Down
71 changes: 48 additions & 23 deletions taichi/llvm/llvm_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ LlvmProgramImpl::clone_struct_compiler_initial_context(
return tlctx->clone_runtime_module();
}

void LlvmProgramImpl::initialize_llvm_runtime_snodes(const SNodeTree *tree,
StructCompiler *scomp,
uint64 *result_buffer) {
void LlvmProgramImpl::initialize_llvm_runtime_snodes(
const LlvmOfflineCache::FieldCacheData &field_cache_data,
uint64 *result_buffer) {
TaichiLLVMContext *tlctx = nullptr;
if (config->arch == Arch::cuda) {
#if defined(TI_WITH_CUDA)
Expand All @@ -175,15 +175,16 @@ void LlvmProgramImpl::initialize_llvm_runtime_snodes(const SNodeTree *tree,
auto *const runtime_jit = tlctx->runtime_jit_module;
// By the time this creator is called, "this" is already destroyed.
// Therefore it is necessary to capture members by values.
const auto snodes = scomp->snodes;
const int root_id = tree->root()->id;
size_t root_size = field_cache_data.root_size;
const auto snode_metas = field_cache_data.snode_metas;
const int tree_id = field_cache_data.tree_id;
const int root_id = field_cache_data.root_id;

TI_TRACE("Allocating data structure of size {} bytes", scomp->root_size);
std::size_t rounded_size =
taichi::iroundup(scomp->root_size, taichi_page_size);
TI_TRACE("Allocating data structure of size {} bytes", root_size);
std::size_t rounded_size = taichi::iroundup(root_size, taichi_page_size);

Ptr root_buffer = snode_tree_buffer_manager_->allocate(
runtime_jit, llvm_runtime_, rounded_size, taichi_page_size, tree->id(),
runtime_jit, llvm_runtime_, rounded_size, taichi_page_size, tree_id,
result_buffer);
if (config->arch == Arch::cuda) {
#if defined(TI_WITH_CUDA)
Expand All @@ -207,33 +208,33 @@ void LlvmProgramImpl::initialize_llvm_runtime_snodes(const SNodeTree *tree,
alloc = cpu_device()->import_memory(root_buffer, rounded_size);
}

snode_tree_allocs_[tree->id()] = alloc;
snode_tree_allocs_[tree_id] = alloc;

bool all_dense = config->demote_dense_struct_fors;
for (int i = 0; i < (int)snodes.size(); i++) {
if (snodes[i]->type != SNodeType::dense &&
snodes[i]->type != SNodeType::place &&
snodes[i]->type != SNodeType::root) {
for (size_t i = 0; i < snode_metas.size(); i++) {
if (snode_metas[i].type != SNodeType::dense &&
snode_metas[i].type != SNodeType::place &&
snode_metas[i].type != SNodeType::root) {
all_dense = false;
break;
}
}

runtime_jit->call<void *, std::size_t, int, int, int, std::size_t, Ptr>(
"runtime_initialize_snodes", llvm_runtime_, scomp->root_size, root_id,
(int)snodes.size(), tree->id(), rounded_size, root_buffer, all_dense);
"runtime_initialize_snodes", llvm_runtime_, root_size, root_id,
(int)snode_metas.size(), tree_id, rounded_size, root_buffer, all_dense);

for (int i = 0; i < (int)snodes.size(); i++) {
if (is_gc_able(snodes[i]->type)) {
const auto snode_id = snodes[i]->id;
for (size_t i = 0; i < snode_metas.size(); i++) {
if (is_gc_able(snode_metas[i].type)) {
const auto snode_id = snode_metas[i].id;
std::size_t node_size;
auto element_size = snodes[i]->cell_size_bytes;
if (snodes[i]->type == SNodeType::pointer) {
auto element_size = snode_metas[i].cell_size_bytes;
if (snode_metas[i].type == SNodeType::pointer) {
// pointer. Allocators are for single elements
node_size = element_size;
} else {
// dynamic. Allocators are for the chunks
node_size = sizeof(void *) + element_size * snodes[i]->chunk_size;
node_size = sizeof(void *) + element_size * snode_metas[i].chunk_size;
}
TI_TRACE("Initializing allocator for snode {} (node size {})", snode_id,
node_size);
Expand Down Expand Up @@ -275,10 +276,34 @@ void LlvmProgramImpl::compile_snode_tree_types(SNodeTree *tree) {
compile_snode_tree_types_impl(tree);
}

/**
 * Builds the FieldCacheData record for an SNode tree.
 *
 * @param tree            Source of the tree id and the root SNode id.
 * @param struct_compiler Source of the root size and of the SNode list whose
 *                        id, type, cell size and chunk size are copied.
 * @return A FieldCacheData holding one SNodeCacheData per entry of
 *         `struct_compiler.snodes`, in the same order.
 */
static LlvmOfflineCache::FieldCacheData construct_filed_cache_data(
    const SNodeTree &tree,
    const StructCompiler &struct_compiler) {
  LlvmOfflineCache::FieldCacheData ret;
  ret.tree_id = tree.id();
  ret.root_id = tree.root()->id;
  ret.root_size = struct_compiler.root_size;

  const auto &snodes = struct_compiler.snodes;
  // The final element count is known, so allocate once instead of growing the
  // vector repeatedly while appending.
  ret.snode_metas.reserve(snodes.size());
  for (const auto &snode : snodes) {
    LlvmOfflineCache::FieldCacheData::SNodeCacheData snode_cache_data;
    snode_cache_data.id = snode->id;
    snode_cache_data.type = snode->type;
    snode_cache_data.cell_size_bytes = snode->cell_size_bytes;
    snode_cache_data.chunk_size = snode->chunk_size;

    // The record only holds scalars, so a plain copy is as cheap as a move.
    ret.snode_metas.push_back(snode_cache_data);
  }

  return ret;
}

/**
 * Materializes an SNode tree for the LLVM backend: compiles the tree's types,
 * then initializes the LLVM runtime SNodes for it.
 *
 * @param tree          SNode tree to materialize.
 * @param result_buffer Forwarded unchanged to
 *                      initialize_llvm_runtime_snodes().
 */
void LlvmProgramImpl::materialize_snode_tree(SNodeTree *tree,
uint64 *result_buffer) {
auto struct_compiler = compile_snode_tree_types_impl(tree);
// NOTE(review): this call passes (tree, StructCompiler *, result_buffer),
// which looks like the pre-refactor form of initialize_llvm_runtime_snodes().
// Confirm whether it should remain alongside the FieldCacheData-based call
// below.
initialize_llvm_runtime_snodes(tree, struct_compiler.get(), result_buffer);

// Flatten the tree and struct-compiler state into a FieldCacheData record and
// initialize the runtime SNodes from that record.
auto field_cache_data = construct_filed_cache_data(*tree, *struct_compiler);
initialize_llvm_runtime_snodes(field_cache_data, result_buffer);
}

uint64 LlvmProgramImpl::fetch_result_uint64(int i, uint64 *result_buffer) {
Expand Down
6 changes: 3 additions & 3 deletions taichi/llvm/llvm_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ class LlvmProgramImpl : public ProgramImpl {
/**
* Initializes the SNodes for LLVM based backends.
*/
void initialize_llvm_runtime_snodes(const SNodeTree *tree,
StructCompiler *scomp,
uint64 *result_buffer);
void initialize_llvm_runtime_snodes(
const LlvmOfflineCache::FieldCacheData &field_cache_data,
uint64 *result_buffer);

uint64 fetch_result_uint64(int i, uint64 *result_buffer);

Expand Down

0 comments on commit 88f75a9

Please sign in to comment.