From 482e82dcb5e2533d5d7a1631bacb1a98d4b665a5 Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Thu, 11 Jun 2020 20:47:00 +0900 Subject: [PATCH] [metal] Add kernel side memory allocator (#1175) * [metal] Add kernel side memory allocator * disable fmt * [skip ci] enforce code format Co-authored-by: Taichi Gardener --- taichi/backends/metal/codegen_metal.cpp | 19 +-- taichi/backends/metal/kernel_manager.cpp | 79 +++++++----- taichi/backends/metal/shaders/atomic_stubs.h | 1 + .../metal/shaders/runtime_kernels.metal.h | 36 +++--- .../metal/shaders/runtime_structs.metal.h | 27 +++-- .../metal/shaders/runtime_utils.metal.h | 112 ++++++++++++++---- taichi/backends/metal/struct_metal.cpp | 21 +--- taichi/backends/metal/struct_metal.h | 29 ++--- 8 files changed, 206 insertions(+), 118 deletions(-) diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index 77c75f4c76e7c..22aa30f0e9752 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -36,6 +36,7 @@ constexpr char kLinearLoopIndexName[] = "linear_loop_idx_"; constexpr char kListgenElemVarName[] = "listgen_elem_"; constexpr char kRandStateVarName[] = "rand_state_"; constexpr char kSNodeMetaVarName[] = "sn_meta_"; +constexpr char kMemAllocVarName[] = "mem_alloc_"; std::string buffer_to_name(BuffersEnum b) { switch (b) { @@ -757,11 +758,13 @@ class KernelCodegen : public IRVisitor { emit("device Runtime *{} = reinterpret_cast({});", kRuntimeVarName, kRuntimeBufferName); emit( - "device byte *list_data_addr = reinterpret_cast({} + " - "1);", - kRuntimeVarName); - emit("device ListManager *parent_list = &({}->snode_lists[{}]);", - kRuntimeVarName, sn_id); + "device MemoryAllocator *{} = reinterpret_cast({} + 1);", + kMemAllocVarName, kRuntimeVarName); + emit("ListManager parent_list;"); + emit("parent_list.lm_data = ({}->snode_lists + {});", kRuntimeVarName, + sn_id); + emit("parent_list.mem_alloc = {};", kMemAllocVarName); emit("const SNodeMeta parent_meta = {}->snode_metas[{}];", kRuntimeVarName, sn_id); emit("const int child_stride = parent_meta.element_stride;"); @@ -773,11 +776,11 @@ class KernelCodegen : public IRVisitor { { ScopedIndent s2(current_appender()); emit("const int parent_idx_ = (ii / child_num_slots);"); - emit("if (parent_idx_ >= num_active(parent_list)) return;"); + emit("if (parent_idx_ >= num_active(&parent_list)) return;"); emit("const int child_idx_ = (ii % child_num_slots);"); emit( - "const auto parent_elem_ = get(parent_list, " - "parent_idx_, list_data_addr);"); + "const auto parent_elem_ = get(&parent_list, " + "parent_idx_);"); emit("ListgenElement {};", kListgenElemVarName); // No need to add mem_offset_in_parent, because place() always starts at 0 diff --git a/taichi/backends/metal/kernel_manager.cpp b/taichi/backends/metal/kernel_manager.cpp index b5c605c7d4c67..fc6ed5e1b26f0 100644 --- a/taichi/backends/metal/kernel_manager.cpp +++ b/taichi/backends/metal/kernel_manager.cpp @@ -148,6 +148,7 @@ class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase { public: struct Params : public CompiledMtlKernelBase::Params { MemoryPool *mem_pool = nullptr; + const SNodeDescriptorsMap *snode_descriptors = nullptr; const SNode *snode() const { return kernel_attribs->runtime_list_op_attribs.snode; @@ -159,7 +160,7 @@ class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase { parent_snode_id_(params.snode()->parent->id), child_snode_id_(params.snode()->id), args_mem_(std::make_unique( - /*size=*/sizeof(int32_t) * 2, + 
/*size=*/sizeof(int32_t) * 3, params.mem_pool)), args_buffer_(new_mtl_buffer_no_copy(params.device, args_mem_->ptr(), @@ -168,6 +169,8 @@ class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase { auto *mem = reinterpret_cast(args_mem_->ptr()); mem[0] = parent_snode_id_; mem[1] = child_snode_id_; + const auto &sn_descs = *params.snode_descriptors; + mem[2] = total_num_self_from_root(sn_descs, child_snode_id_); } void launch(InputBuffersMap &input_buffers, @@ -189,6 +192,7 @@ class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase { // For such Metal kernels, it always takes in an args buffer of two int32's: // args[0] = parent_snode_id // args[1] = child_snode_id + // args[2] = child_snode.total_num_self_from_root // Note that this args buffer has nothing to do with the one passed to Taichi // kernel. // See taichi/backends/metal/shaders/runtime_kernels.metal.h @@ -205,6 +209,7 @@ class CompiledTaichiKernel { std::string mtl_source_code; const std::vector *mtl_kernels_attribs; const KernelContextAttributes *ctx_attribs; + const SNodeDescriptorsMap *snode_descriptors; MTLDevice *device; MemoryPool *mem_pool; ProfilerBase *profiler; @@ -230,6 +235,7 @@ class CompiledTaichiKernel { kparams.device = device; kparams.mtl_func = mtl_func.get(); kparams.mem_pool = params.mem_pool; + kparams.snode_descriptors = params.snode_descriptors; kernel = std::make_unique(kparams); } else { UserMtlKernel::Params kparams; @@ -413,18 +419,21 @@ class KernelManager::Impl { device_.get(), global_tmps_mem_->ptr(), global_tmps_mem_->size()); TI_ASSERT(global_tmps_buffer_ != nullptr); - if (compiled_structs_.runtime_size > 0) { - runtime_mem_ = std::make_unique( - compiled_structs_.runtime_size, mem_pool_); - runtime_buffer_ = new_mtl_buffer_no_copy( - device_.get(), runtime_mem_->ptr(), runtime_mem_->size()); - TI_DEBUG("Metal runtime buffer size: {} bytes", runtime_mem_->size()); - TI_ASSERT_INFO( - runtime_buffer_ != nullptr, - "Failed to allocate Metal runtime buffer, requested {} bytes", - runtime_mem_->size()); - init_runtime(params.root_id); - } + TI_ASSERT(compiled_structs_.runtime_size > 0); + const int mem_pool_bytes = (config_->device_memory_GB * 1024 * 1024 * 1024); + runtime_mem_ = std::make_unique( + compiled_structs_.runtime_size + mem_pool_bytes, mem_pool_); + runtime_buffer_ = new_mtl_buffer_no_copy(device_.get(), runtime_mem_->ptr(), + runtime_mem_->size()); + TI_DEBUG( + "Metal runtime buffer size: {} bytes (sizeof(Runtime)={} " + "memory_pool={})", + runtime_mem_->size(), compiled_structs_.runtime_size, mem_pool_bytes); + TI_ASSERT_INFO( + runtime_buffer_ != nullptr, + "Failed to allocate Metal runtime buffer, requested {} bytes", + runtime_mem_->size()); + init_runtime(params.root_id); } void register_taichi_kernel( @@ -447,6 +456,7 @@ class KernelManager::Impl { params.mtl_source_code = mtl_kernel_source_code; params.mtl_kernels_attribs = &kernels_attribs; params.ctx_attribs = &ctx_attribs; + params.snode_descriptors = &compiled_structs_.snode_descriptors; params.device = device_.get(); params.mem_pool = mem_pool_; params.profiler = profiler_; @@ -560,32 +570,27 @@ class KernelManager::Impl { TI_DEBUG("Initialized SNodeExtractors, size={} accumuated={}", addr_offset, (addr - addr_begin)); // init snode_lists - ListManager *const rtm_list_head = reinterpret_cast(addr); - int list_data_mem_begin = 0; + ListManagerData *const rtm_list_head = + reinterpret_cast(addr); for (int i = 0; i < max_snodes; ++i) { auto iter = snode_descriptors.find(i); if (iter == snode_descriptors.end()) { 
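As a side note on the sizing change earlier in this file's hunks: the runtime buffer used to be allocated only when runtime_size > 0 and held just the Runtime struct plus the static list storage; it now always holds sizeof(Runtime) plus a kernel-side memory pool whose size comes from device_memory_GB. A standalone C++ sketch of that arithmetic; the 1 MB runtime_size and the 1.0 GB setting are made-up example values, not taken from the patch:

    #include <cstdio>

    int main() {
      // Example values only: device_memory_GB mirrors the user-facing config
      // knob, runtime_size stands in for compiled_structs_.runtime_size.
      const double device_memory_GB = 1.0;
      const long long runtime_size = 1 << 20;
      const long long mem_pool_bytes =
          static_cast<long long>(device_memory_GB * 1024 * 1024 * 1024);
      // One buffer now backs both the Runtime struct and the kernel-side pool.
      std::printf("Metal runtime buffer: %lld bytes (sizeof(Runtime)=%lld, pool=%lld)\n",
                  runtime_size + mem_pool_bytes, runtime_size, mem_pool_bytes);
      return 0;
    }
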
        continue;
      }
      const SNodeDescriptor &sn_desc = iter->second;
-      ListManager *rtm_list = reinterpret_cast<ListManager *>(addr) + i;
+      ListManagerData *rtm_list = reinterpret_cast<ListManagerData *>(addr) + i;
       rtm_list->element_stride = sizeof(ListgenElement);
-      // This can be really large, especially for other sparse SNodes (e.g.
-      // dynamic, hash). In the future, Metal might also be able to support
-      // dynamic memory allocation from the kernel side. That should help reduce
-      // the initial size.
-      rtm_list->max_num_elems =
-          sn_desc.total_num_self_from_root(snode_descriptors);
+      const int num_elems_per_chunk = compute_num_elems_per_chunk(
+          sn_desc.total_num_self_from_root(snode_descriptors));
+      rtm_list->log2_num_elems_per_chunk = log2int(num_elems_per_chunk);
       rtm_list->next = 0;
-      rtm_list->mem_begin = list_data_mem_begin;
-      list_data_mem_begin += rtm_list->max_num_elems * rtm_list->element_stride;
-      TI_DEBUG("ListManager\n  id={}\n  num_slots={}\n  mem_begin={}\n", i,
-               rtm_list->max_num_elems, rtm_list->mem_begin);
+      TI_DEBUG("ListManagerData\n  id={}\n  num_elems_per_chunk={}\n", i,
+               num_elems_per_chunk);
     }
-    addr_offset = sizeof(ListManager) * max_snodes;
+    addr_offset = sizeof(ListManagerData) * max_snodes;
     addr += addr_offset;
-    TI_DEBUG("Initialized ListManager, size={} accumuated={}", addr_offset,
+    TI_DEBUG("Initialized ListManagerData, size={} accumulated={}", addr_offset,
              (addr - addr_begin));
     // init rand_seeds
     // TODO(k-ye): Provide a way to use a fixed seed in dev mode.
@@ -604,14 +609,30 @@
              kNumRandSeeds * sizeof(uint32_t), (addr - addr_begin));
 
     if (compiled_structs_.need_snode_lists_data) {
+      auto *alloc = reinterpret_cast<MemoryAllocator *>(addr);
+      // Make sure the returned memory address is always greater than 1.
+      alloc->next = shaders::kAlignment;
       // root list data are static
       ListgenElement root_elem;
       root_elem.root_mem_offset = 0;
       for (int i = 0; i < taichi_max_num_indices; ++i) {
         root_elem.coords[i] = 0;
       }
-      append(rtm_list_head + root_id, root_elem, addr);
+      ListManager root_lm;
+      root_lm.lm_data = rtm_list_head + root_id;
+      root_lm.mem_alloc = alloc;
+      append(&root_lm, root_elem);
+    }
+  }
+
+  static int compute_num_elems_per_chunk(int n) {
+    const int lb =
+        (n + shaders::kTaichiNumChunks - 1) / shaders::kTaichiNumChunks;
+    int result = 1024;
+    while (result < lb) {
+      result <<= 1;
     }
+    return result;
   }
 
   void create_new_command_buffer() {
diff --git a/taichi/backends/metal/shaders/atomic_stubs.h b/taichi/backends/metal/shaders/atomic_stubs.h
index 11f7f1fead750..fbe0b8534ec94 100644
--- a/taichi/backends/metal/shaders/atomic_stubs.h
+++ b/taichi/backends/metal/shaders/atomic_stubs.h
@@ -14,6 +14,7 @@ template <typename T>
 bool atomic_compare_exchange_weak_explicit(T *object,
                                            T *expected,
                                            T desired,
+                                           metal::memory_order,
                                            metal::memory_order) {
   const T val = *object;
   if (val == *expected) {
diff --git a/taichi/backends/metal/shaders/runtime_kernels.metal.h b/taichi/backends/metal/shaders/runtime_kernels.metal.h
index 99e96cf8bf68f..ce94e51e68dbc 100644
--- a/taichi/backends/metal/shaders/runtime_kernels.metal.h
+++ b/taichi/backends/metal/shaders/runtime_kernels.metal.h
@@ -49,10 +49,11 @@ STR(
       if (utid_ > 0)
         return;
       int child_snode_id = args[1];
-      device ListManager *child_list =
-          &(reinterpret_cast<device Runtime *>(runtime_addr)
-                ->snode_lists[child_snode_id]);
-      clear(child_list);
+      ListManager child_list;
+      child_list.lm_data =
+          (reinterpret_cast<device Runtime *>(runtime_addr)->snode_lists +
+           child_snode_id);
+      clear(&child_list);
     }
 
     kernel void element_listgen(device byte *runtime_addr[[buffer(0)]],
@@ -62,30 +63,33 @@ STR(
                                 const uint utid_[[thread_position_in_grid]],
                                 const uint
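The compute_num_elems_per_chunk() helper added above picks, per SNode, the smallest power of two that is both at least 1024 and large enough that kTaichiNumChunks chunks cover the list's maximum element count. A standalone C++ sketch with worked numbers (the inputs are made up; the constant mirrors kTaichiNumChunks from runtime_structs.metal.h):

    #include <cassert>

    constexpr int kTaichiNumChunks = 1024;  // mirrors the shader-side constant

    static int compute_num_elems_per_chunk(int n) {
      const int lb = (n + kTaichiNumChunks - 1) / kTaichiNumChunks;  // ceil(n / 1024)
      int result = 1024;
      while (result < lb) {
        result <<= 1;
      }
      return result;
    }

    int main() {
      // A list with up to 1M elements fits in 1024 chunks of 1024 elements.
      assert(compute_num_elems_per_chunk(1'000'000) == 1024);
      // 10M elements -> ceil(10M / 1024) = 9766 -> rounded up to 16384 per chunk.
      assert(compute_num_elems_per_chunk(10'000'000) == 16384);
      return 0;
    }

Because the chunk size is a power of two, only its log2 is stored in ListManagerData, so the chunk index and in-chunk index reduce to a shift and a mask.
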
grid_size[[threads_per_grid]]) { device Runtime *runtime = reinterpret_cast(runtime_addr); - device byte *list_data_addr = - reinterpret_cast(runtime + 1); + device MemoryAllocator *mem_alloc = + reinterpret_cast(runtime + 1); - int parent_snode_id = args[0]; - int child_snode_id = args[1]; - device ListManager *parent_list = - &(runtime->snode_lists[parent_snode_id]); - device ListManager *child_list = &(runtime->snode_lists[child_snode_id]); + const int parent_snode_id = args[0]; + const int child_snode_id = args[1]; + ListManager parent_list; + parent_list.lm_data = (runtime->snode_lists + parent_snode_id); + parent_list.mem_alloc = mem_alloc; + ListManager child_list; + child_list.lm_data = (runtime->snode_lists + child_snode_id); + child_list.mem_alloc = mem_alloc; const SNodeMeta parent_meta = runtime->snode_metas[parent_snode_id]; const int child_stride = parent_meta.element_stride; const int num_slots = parent_meta.num_slots; const SNodeMeta child_meta = runtime->snode_metas[child_snode_id]; // |max_num_elems| is NOT padded to power-of-two, while |num_slots| is. // So we need to cap the loop precisely at child's |max_num_elems|. - for (int ii = utid_; ii < child_list->max_num_elems; ii += grid_size) { + const int max_num_elems = args[2]; + for (int ii = utid_; ii < max_num_elems; ii += grid_size) { const int parent_idx = (ii / num_slots); - if (parent_idx >= num_active(parent_list)) { + if (parent_idx >= num_active(&parent_list)) { // Since |parent_idx| increases monotonically, we can return directly // once it goes beyond the number of active parent elements. return; } const int child_idx = (ii % num_slots); - const auto parent_elem = - get(parent_list, parent_idx, list_data_addr); + const auto parent_elem = get(&parent_list, parent_idx); ListgenElement child_elem; child_elem.root_mem_offset = parent_elem.root_mem_offset + child_idx * child_stride + @@ -95,7 +99,7 @@ STR( refine_coordinates(parent_elem, runtime->snode_extractors[parent_snode_id], child_idx, &child_elem); - append(child_list, child_elem, list_data_addr); + append(&child_list, child_elem); } } } diff --git a/taichi/backends/metal/shaders/runtime_structs.metal.h b/taichi/backends/metal/shaders/runtime_structs.metal.h index 2b6220abc3726..578648fca93dc 100644 --- a/taichi/backends/metal/shaders/runtime_structs.metal.h +++ b/taichi/backends/metal/shaders/runtime_structs.metal.h @@ -19,7 +19,7 @@ static_assert(taichi_max_num_indices == 8, "Please update kTaichiMaxNumIndices"); - +static_assert(sizeof(char *) == 8, "Metal pointers are 64-bit."); #define METAL_BEGIN_RUNTIME_STRUCTS_DEF #define METAL_END_RUNTIME_STRUCTS_DEF @@ -30,23 +30,30 @@ METAL_BEGIN_RUNTIME_STRUCTS_DEF STR( // clang-format on constant constexpr int kTaichiMaxNumIndices = 8; + constant constexpr int kTaichiNumChunks = 1024; + + struct MemoryAllocator { atomic_int next; }; struct ListgenElement { int32_t coords[kTaichiMaxNumIndices]; int32_t root_mem_offset = 0; }; - // ListManager manages the activeness of its associated SNode. - struct ListManager { + // ListManagerData manages the activeness of its associated SNode. + struct ListManagerData { int32_t element_stride = 0; - // Total number of this SNode in the hierarchy. - // Same as |total_num_self_from_root| of this SNode. - int32_t max_num_elems = 0; + + int32_t log2_num_elems_per_chunk = 0; // Index to the next element in this list. - // |next| can never go beyond |max_num_elems|. - int32_t next = 0; - // The data offset from the runtime memory beginning. 
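To make the chunked ListManagerData layout introduced above concrete: a list element's flat index (handed out by |next|) is split into a chunk index and an in-chunk slot with shift/mask arithmetic, and total capacity is bounded by the fixed chunks[kTaichiNumChunks] table. A minimal host-side C++ model (the struct and the values are illustrative, not the real Metal types):

    #include <cassert>
    #include <cstdint>

    constexpr int kTaichiNumChunks = 1024;

    struct ListManagerDataModel {
      int32_t element_stride = 8;
      int32_t log2_num_elems_per_chunk = 10;  // 1024 elements per chunk
    };

    int main() {
      ListManagerDataModel lm;
      const int32_t i = 123'456;  // flat element index handed out by |next|
      const int32_t chunk_idx = i >> lm.log2_num_elems_per_chunk;        // 120
      const int32_t mask = (1 << lm.log2_num_elems_per_chunk) - 1;
      const int32_t in_chunk = i & mask;                                 // 576
      const int64_t byte_off_in_chunk =
          int64_t(in_chunk) * lm.element_stride;                         // 4608
      assert(chunk_idx == 120 && in_chunk == 576 && byte_off_in_chunk == 4608);
      // Total capacity is bounded by the fixed chunk table:
      const int64_t capacity = int64_t(kTaichiNumChunks)
                               << lm.log2_num_elems_per_chunk;           // 1'048'576
      assert(capacity == 1'048'576);
      return 0;
    }
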
- int32_t mem_begin = 0; + // |next| can never go beyond |kTaichiNumChunks| * |num_elems_per_chunk|. + atomic_int next; + + atomic_int chunks[kTaichiNumChunks]; + }; + + struct ListManager { + device ListManagerData *lm_data; + device MemoryAllocator *mem_alloc; }; // This class is very similar to metal::SNodeDescriptor diff --git a/taichi/backends/metal/shaders/runtime_utils.metal.h b/taichi/backends/metal/shaders/runtime_utils.metal.h index 63197386d0f68..36e53cb790c04 100644 --- a/taichi/backends/metal/shaders/runtime_utils.metal.h +++ b/taichi/backends/metal/shaders/runtime_utils.metal.h @@ -31,38 +31,100 @@ // clang-format off METAL_BEGIN_RUNTIME_UTILS_DEF STR( - [[maybe_unused]] int num_active(device const ListManager *list) { - return list->next; + using PtrOffset = int32_t; + constant constexpr int kAlignment = 8; + + [[maybe_unused]] PtrOffset mtl_memalloc_alloc(device MemoryAllocator * ma, + int32_t size) { + size = ((size + kAlignment - 1) / kAlignment) * kAlignment; + return atomic_fetch_add_explicit(&ma->next, size, + metal::memory_order_relaxed); + } + + [[maybe_unused]] device char + *mtl_memalloc_to_ptr(device MemoryAllocator *ma, PtrOffset offs) { + return reinterpret_cast(ma + 1) + offs; + } + + [[maybe_unused]] int num_active(thread ListManager *l) { + return atomic_load_explicit(&(l->lm_data->next), + metal::memory_order_relaxed); + } + + [[maybe_unused]] void clear(thread ListManager *l) { + atomic_store_explicit(&(l->lm_data->next), 0, + metal::memory_order_relaxed); + } + + [[maybe_unused]] PtrOffset mtl_listmgr_ensure_chunk(thread ListManager *l, + int i) { + device ListManagerData *list = l->lm_data; + PtrOffset offs = 0; + const int kChunkBytes = + (list->element_stride << list->log2_num_elems_per_chunk); + + while (true) { + int stored = 0; + // If chunks[i] is unallocated, i.e. 0, mark it as 1 to prevent others + // from requesting memory again. Once allocated, set chunks[i] to the + // actual address offset, which is guaranteed to be greater than 1. 
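mtl_memalloc_alloc() defined above is a simple lock-free bump allocator: round the request up to kAlignment, atomically advance a single cursor, and return the previous cursor value as a byte offset into the pool behind the MemoryAllocator header. Because the host initializes |next| to kAlignment (see init_runtime), a valid offset is always greater than 1, which is what lets chunks[] reserve 0 and 1 as sentinels. A host-side sketch with std::atomic, illustrative only:

    #include <atomic>
    #include <cassert>
    #include <cstdint>

    constexpr int32_t kAlignment = 8;  // mirrors the shader-side constant

    struct MemoryAllocatorModel {
      std::atomic<int32_t> next{kAlignment};  // starts past the 0/1 sentinel values
    };

    // Returns a byte offset into the pool; never 0 or 1 because |next| starts
    // at kAlignment.
    int32_t memalloc_alloc(MemoryAllocatorModel *ma, int32_t size) {
      size = ((size + kAlignment - 1) / kAlignment) * kAlignment;
      return ma->next.fetch_add(size, std::memory_order_relaxed);
    }

    int main() {
      MemoryAllocatorModel ma;
      const int32_t a = memalloc_alloc(&ma, 20);  // request rounded up to 24
      const int32_t b = memalloc_alloc(&ma, 4);   // request rounded up to 8
      assert(a == 8 && b == 32 && ma.next.load() == 40);
      return 0;
    }
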
+ const bool is_me = atomic_compare_exchange_weak_explicit( + list->chunks + i, &stored, 1, metal::memory_order_relaxed, + metal::memory_order_relaxed); + if (is_me) { + offs = mtl_memalloc_alloc(l->mem_alloc, kChunkBytes); + atomic_store_explicit(list->chunks + i, offs, + metal::memory_order_relaxed); + break; + } else if (stored > 1) { + offs = stored; + break; + } + // |stored| == 1, just spin + } + return offs; + } + + [[maybe_unused]] device char *mtl_listmgr_get_elem_from_chunk( + thread ListManager *l, + int i, + PtrOffset chunk_ptr_offs) { + device ListManagerData *list = l->lm_data; + device char *chunk_ptr = reinterpret_cast( + mtl_memalloc_to_ptr(l->mem_alloc, chunk_ptr_offs)); + const uint32_t mask = ((1 << list->log2_num_elems_per_chunk) - 1); + return chunk_ptr + ((i & mask) * list->element_stride); + } + + [[maybe_unused]] device char *append(thread ListManager *l) { + device ListManagerData *list = l->lm_data; + const int elem_idx = atomic_fetch_add_explicit( + &list->next, 1, metal::memory_order_relaxed); + const int chunk_idx = elem_idx >> list->log2_num_elems_per_chunk; + const PtrOffset chunk_ptr_offs = mtl_listmgr_ensure_chunk(l, chunk_idx); + return mtl_listmgr_get_elem_from_chunk(l, elem_idx, chunk_ptr_offs); } template - int append(device ListManager *list, - thread const T &elem, - device byte *data_addr) { + [[maybe_unused]] void append(thread ListManager *l, thread const T &elem) { + device char *ptr = append(l); thread char *elem_ptr = (thread char *)(&elem); - int me = atomic_fetch_add_explicit( - reinterpret_cast(&(list->next)), 1, - metal::memory_order_relaxed); - device byte *ptr = - data_addr + list->mem_begin + (me * list->element_stride); - for (int i = 0; i < list->element_stride; ++i) { + + for (int i = 0; i < l->lm_data->element_stride; ++i) { *ptr = *elem_ptr; ++ptr; ++elem_ptr; } - return me; } template - T get(const device ListManager *list, int i, device const byte *data_addr) { - return *reinterpret_cast(data_addr + list->mem_begin + - (i * list->element_stride)); - } - - [[maybe_unused]] void clear(device ListManager *list) { - atomic_store_explicit( - reinterpret_cast(&(list->next)), 0, - metal::memory_order_relaxed); + [[maybe_unused]] T get(thread ListManager *l, int i) { + device ListManagerData *list = l->lm_data; + const int chunk_idx = i >> list->log2_num_elems_per_chunk; + const PtrOffset chunk_ptr_offs = atomic_load_explicit( + list->chunks + chunk_idx, metal::memory_order_relaxed); + return *reinterpret_cast( + mtl_listmgr_get_elem_from_chunk(l, i, chunk_ptr_offs)); } [[maybe_unused]] int is_active(device byte *addr, SNodeMeta meta, int i) { @@ -70,7 +132,7 @@ STR( return true; } device auto *meta_ptr_begin = reinterpret_cast( - addr + ((meta.num_slots - i) * meta.element_stride)); + addr + ((meta.num_slots - i) * meta.element_stride)); if (meta.type == SNodeMeta::Dynamic) { device auto *ptr = meta_ptr_begin; uint32_t n = atomic_load_explicit(ptr, metal::memory_order_relaxed); @@ -104,7 +166,7 @@ STR( return; } device auto *meta_ptr_begin = reinterpret_cast( - addr + ((meta.num_slots - i) * meta.element_stride)); + addr + ((meta.num_slots - i) * meta.element_stride)); if (meta.type == SNodeMeta::Dynamic) { device auto *ptr = meta_ptr_begin; // For dynamic, deactivate() applies for all the slots @@ -130,8 +192,8 @@ STR( } [[maybe_unused]] int dynamic_append(device byte *addr, - SNodeMeta meta, - int32_t data) { + SNodeMeta meta, + int32_t data) { // |addr| always starts at the beginning of the dynamic device auto *n_ptr = 
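The compare-exchange loop in mtl_listmgr_ensure_chunk() above implements a small per-slot state machine: 0 means the chunk is unallocated, 1 means some thread is allocating it right now, and any value greater than 1 is the published pool offset. Exactly one thread wins the 0 -> 1 transition, allocates, and publishes the offset; everyone else either reuses the published offset or spins while the slot is still 1. A single-threaded C++ model of that protocol (the function names and the fake allocator are hypothetical):

    #include <atomic>
    #include <cassert>
    #include <cstdint>

    // 0: unallocated, 1: allocation in progress, >1: pool offset of the chunk.
    int32_t ensure_chunk(std::atomic<int32_t> *slot,
                         int32_t (*alloc_chunk)(void)) {
      while (true) {
        int32_t stored = 0;
        // Try to claim the slot by flipping 0 -> 1.
        if (slot->compare_exchange_weak(stored, 1, std::memory_order_relaxed)) {
          const int32_t offs = alloc_chunk();  // bump-allocate the chunk
          slot->store(offs, std::memory_order_relaxed);
          return offs;
        }
        if (stored > 1) {
          return stored;  // someone else already published the offset
        }
        // stored == 1: another thread is allocating right now, spin and retry.
      }
    }

    int main() {
      std::atomic<int32_t> slot{0};
      auto fake_alloc = +[]() -> int32_t { return 4096; };
      assert(ensure_chunk(&slot, fake_alloc) == 4096);  // first caller allocates
      assert(ensure_chunk(&slot, fake_alloc) == 4096);  // later callers reuse it
      return 0;
    }
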
reinterpret_cast( addr + (meta.num_slots * meta.element_stride)); diff --git a/taichi/backends/metal/struct_metal.cpp b/taichi/backends/metal/struct_metal.cpp index 0e9f5baee39d2..9d976998f5e5f 100644 --- a/taichi/backends/metal/struct_metal.cpp +++ b/taichi/backends/metal/struct_metal.cpp @@ -26,7 +26,6 @@ namespace shaders { } // namespace shaders -constexpr size_t kListgenElementSize = sizeof(shaders::ListgenElement); constexpr size_t kListManagerSize = sizeof(shaders::ListManager); constexpr size_t kSNodeMetaSize = sizeof(shaders::SNodeMeta); constexpr size_t kSNodeExtractorsSize = sizeof(shaders::SNodeExtractors); @@ -221,7 +220,7 @@ class StructCompiler { emit("struct Runtime {{"); emit(" SNodeMeta snode_metas[{}];", max_snodes_); emit(" SNodeExtractors snode_extractors[{}];", max_snodes_); - emit(" ListManager snode_lists[{}];", max_snodes_); + emit(" ListManagerData snode_lists[{}];", max_snodes_); emit(" uint32_t rand_seeds[{}];", kNumRandSeeds); emit("}};"); } @@ -230,20 +229,6 @@ class StructCompiler { size_t result = (max_snodes_) * (kSNodeMetaSize + kSNodeExtractorsSize + kListManagerSize); result += sizeof(uint32_t) * kNumRandSeeds; - TI_DEBUG("Metal sizeof(Runtime): {} bytes", result); - if (has_sparse_snode_) { - // We only need additional memory to hold sparsity information. Don't - // allocate it if there is no sparse SNode at all. - int total_items = 0; - for (const auto &kv : snode_descriptors_) { - total_items += kv.second.total_num_self_from_root(snode_descriptors_); - } - const size_t list_data_size = total_items * kListgenElementSize; - TI_DEBUG("Metal runtime sparse list data size: {} bytes", list_data_size); - result += list_data_size; - } else { - TI_TRACE("Metal runtime doesn't need additional memory for snode_lists"); - } return result; } @@ -271,6 +256,10 @@ int SNodeDescriptor::total_num_self_from_root( return sn_descs.find(psn->id)->second.total_num_elems_from_root; } +int total_num_self_from_root(const SNodeDescriptorsMap &m, int snode_id) { + return m.at(snode_id).total_num_self_from_root(m); +} + CompiledStructs compile_structs(SNode &root) { return StructCompiler().run(root); } diff --git a/taichi/backends/metal/struct_metal.h b/taichi/backends/metal/struct_metal.h index b11e26a8f1d92..8b59935d688a7 100644 --- a/taichi/backends/metal/struct_metal.h +++ b/taichi/backends/metal/struct_metal.h @@ -36,6 +36,11 @@ struct SNodeDescriptor { const std::unordered_map &sn_descs) const; }; +using SNodeDescriptorsMap = std::unordered_map; + +// See SNodeDescriptor::total_num_self_from_root +int total_num_self_from_root(const SNodeDescriptorsMap &m, int snode_id); + struct CompiledStructs { // Source code of the SNode data structures compiled to Metal std::string snode_structs_source_code; @@ -49,34 +54,30 @@ struct CompiledStructs { // struct Runtime { // SNodeMeta snode_metas[max_snodes]; // SNodeExtractors snode_extractors[max_snodes]; - // ListManager snode_lists[max_snodes]; + // ListManagerData snode_lists[max_snodes]; // uint32_t rand_seeds[kNumRandSeeds]; // }; // - // If |need_snode_lists_data| is `true`, |runtime_size| will be greater than - // sizeof(Runtime). This is because the memory is divided into two parts. The - // first part, with size being sizeof(Runtime), is used to hold the Runtime - // struct as expected. The second part is used to hold the data of - // |snode_lists|. 
-  //
-  // |---- Runtime ----|--------------- |snode_lists| data ---------------|
-  // |<------------------------- runtime_size --------------------------->|
+  // |runtime_size| will be sizeof(Runtime), which is useful for allocating the
+  // buffer memory.
   //
-  // The actual data address for the i-th ListManager is then:
-  // runtime memory address + list[i].mem_begin
+  // If |need_snode_lists_data| is true, the buffer will consist of two parts.
+  // The first part, with size being |runtime_size|, is used to hold the Runtime
+  // struct as expected. The second part is used as a kernel-side memory pool.
   //
-  // Otherwise if |need_snode_lists_data| is `false`, |runtime_size| will be
-  // equal to sizeof(Runtime).
+  // |------ Runtime -----|--------------- Metal memory pool ---------------|
+  // |<-- runtime_size -->|<------- decided by config, usually ~GB -------->|
   //
   // TODO(k-ye): See if Metal ArgumentBuffer can directly store the pointers.
   size_t runtime_size;
   // In case there is no sparse SNode (e.g. bitmasked), we don't need to
   // allocate the additional memory for |snode_lists|.
   // TODO(k-ye): Rename to |needs_kernel_memory_allocator|.
   bool need_snode_lists_data;
   // max(ID of Root or Dense Snode) + 1
   int max_snodes;
   // Map from SNode ID to its descriptor.
-  std::unordered_map<int, SNodeDescriptor> snode_descriptors;
+  SNodeDescriptorsMap snode_descriptors;
 };
 
 // Compile all snodes to Metal source code
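Putting the pieces of this patch together, the single Metal buffer now looks like: the Runtime struct first, then the kernel-side pool, whose first bytes are the MemoryAllocator header. Device code locates the allocator with reinterpret_cast<device MemoryAllocator *>(runtime + 1), and every offset returned by mtl_memalloc_alloc is relative to the byte right after that header. A host-side C++ sketch of the pointer arithmetic (mock struct sizes, not the real ones):

    #include <cassert>
    #include <cstdint>

    struct MemoryAllocatorModel { int32_t next; };
    struct RuntimeModel { int32_t dummy[4]; };  // stand-in for the real Runtime

    int main() {
      alignas(8) unsigned char buffer[256] = {};
      auto *runtime = reinterpret_cast<RuntimeModel *>(buffer);
      // Mirrors: mem_alloc = reinterpret_cast<device MemoryAllocator *>(runtime + 1);
      auto *mem_alloc = reinterpret_cast<MemoryAllocatorModel *>(runtime + 1);
      assert(reinterpret_cast<unsigned char *>(mem_alloc) ==
             buffer + sizeof(RuntimeModel));
      // Mirrors: mtl_memalloc_to_ptr(ma, offs) == (char *)(ma + 1) + offs, so
      // offset 0 is the first byte after the MemoryAllocator header.
      unsigned char *pool_begin = reinterpret_cast<unsigned char *>(mem_alloc + 1);
      assert(pool_begin ==
             buffer + sizeof(RuntimeModel) + sizeof(MemoryAllocatorModel));
      return 0;
    }
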