Skip to content

Commit

Permalink
[metal] Add kernel side memory allocator (#1175)
Browse files Browse the repository at this point in the history
* [metal] Add kernel side memory allocator

* disable fmt

* [skip ci] enforce code format

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
k-ye and taichi-gardener authored Jun 11, 2020
1 parent 2169c98 commit 482e82d
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 118 deletions.
19 changes: 11 additions & 8 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ constexpr char kLinearLoopIndexName[] = "linear_loop_idx_";
constexpr char kListgenElemVarName[] = "listgen_elem_";
constexpr char kRandStateVarName[] = "rand_state_";
constexpr char kSNodeMetaVarName[] = "sn_meta_";
constexpr char kMemAllocVarName[] = "mem_alloc_";

std::string buffer_to_name(BuffersEnum b) {
switch (b) {
Expand Down Expand Up @@ -757,11 +758,13 @@ class KernelCodegen : public IRVisitor {
emit("device Runtime *{} = reinterpret_cast<device Runtime *>({});",
kRuntimeVarName, kRuntimeBufferName);
emit(
"device byte *list_data_addr = reinterpret_cast<device byte *>({} + "
"1);",
kRuntimeVarName);
emit("device ListManager *parent_list = &({}->snode_lists[{}]);",
kRuntimeVarName, sn_id);
"device MemoryAllocator *{} = reinterpret_cast<device MemoryAllocator "
"*>({} + 1);",
kMemAllocVarName, kRuntimeVarName);
emit("ListManager parent_list;");
emit("parent_list.lm_data = ({}->snode_lists + {});", kRuntimeVarName,
sn_id);
emit("parent_list.mem_alloc = {};", kMemAllocVarName);
emit("const SNodeMeta parent_meta = {}->snode_metas[{}];", kRuntimeVarName,
sn_id);
emit("const int child_stride = parent_meta.element_stride;");
Expand All @@ -773,11 +776,11 @@ class KernelCodegen : public IRVisitor {
{
ScopedIndent s2(current_appender());
emit("const int parent_idx_ = (ii / child_num_slots);");
emit("if (parent_idx_ >= num_active(parent_list)) return;");
emit("if (parent_idx_ >= num_active(&parent_list)) return;");
emit("const int child_idx_ = (ii % child_num_slots);");
emit(
"const auto parent_elem_ = get<ListgenElement>(parent_list, "
"parent_idx_, list_data_addr);");
"const auto parent_elem_ = get<ListgenElement>(&parent_list, "
"parent_idx_);");

emit("ListgenElement {};", kListgenElemVarName);
// No need to add mem_offset_in_parent, because place() always starts at 0
Expand Down
79 changes: 50 additions & 29 deletions taichi/backends/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase {
public:
struct Params : public CompiledMtlKernelBase::Params {
MemoryPool *mem_pool = nullptr;
const SNodeDescriptorsMap *snode_descriptors = nullptr;

const SNode *snode() const {
return kernel_attribs->runtime_list_op_attribs.snode;
Expand All @@ -159,7 +160,7 @@ class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase {
parent_snode_id_(params.snode()->parent->id),
child_snode_id_(params.snode()->id),
args_mem_(std::make_unique<BufferMemoryView>(
/*size=*/sizeof(int32_t) * 2,
/*size=*/sizeof(int32_t) * 3,
params.mem_pool)),
args_buffer_(new_mtl_buffer_no_copy(params.device,
args_mem_->ptr(),
Expand All @@ -168,6 +169,8 @@ class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase {
auto *mem = reinterpret_cast<int32_t *>(args_mem_->ptr());
mem[0] = parent_snode_id_;
mem[1] = child_snode_id_;
const auto &sn_descs = *params.snode_descriptors;
mem[2] = total_num_self_from_root(sn_descs, child_snode_id_);
}

void launch(InputBuffersMap &input_buffers,
Expand All @@ -189,6 +192,7 @@ class RuntimeListOpsMtlKernel : public CompiledMtlKernelBase {
// For such Metal kernels, it always takes in an args buffer of two int32's:
// args[0] = parent_snode_id
// args[1] = child_snode_id
// args[2] = child_snode.total_num_self_from_root
// Note that this args buffer has nothing to do with the one passed to Taichi
// kernel.
// See taichi/backends/metal/shaders/runtime_kernels.metal.h
Expand All @@ -205,6 +209,7 @@ class CompiledTaichiKernel {
std::string mtl_source_code;
const std::vector<KernelAttributes> *mtl_kernels_attribs;
const KernelContextAttributes *ctx_attribs;
const SNodeDescriptorsMap *snode_descriptors;
MTLDevice *device;
MemoryPool *mem_pool;
ProfilerBase *profiler;
Expand All @@ -230,6 +235,7 @@ class CompiledTaichiKernel {
kparams.device = device;
kparams.mtl_func = mtl_func.get();
kparams.mem_pool = params.mem_pool;
kparams.snode_descriptors = params.snode_descriptors;
kernel = std::make_unique<RuntimeListOpsMtlKernel>(kparams);
} else {
UserMtlKernel::Params kparams;
Expand Down Expand Up @@ -413,18 +419,21 @@ class KernelManager::Impl {
device_.get(), global_tmps_mem_->ptr(), global_tmps_mem_->size());
TI_ASSERT(global_tmps_buffer_ != nullptr);

if (compiled_structs_.runtime_size > 0) {
runtime_mem_ = std::make_unique<BufferMemoryView>(
compiled_structs_.runtime_size, mem_pool_);
runtime_buffer_ = new_mtl_buffer_no_copy(
device_.get(), runtime_mem_->ptr(), runtime_mem_->size());
TI_DEBUG("Metal runtime buffer size: {} bytes", runtime_mem_->size());
TI_ASSERT_INFO(
runtime_buffer_ != nullptr,
"Failed to allocate Metal runtime buffer, requested {} bytes",
runtime_mem_->size());
init_runtime(params.root_id);
}
TI_ASSERT(compiled_structs_.runtime_size > 0);
const int mem_pool_bytes = (config_->device_memory_GB * 1024 * 1024 * 1024);
runtime_mem_ = std::make_unique<BufferMemoryView>(
compiled_structs_.runtime_size + mem_pool_bytes, mem_pool_);
runtime_buffer_ = new_mtl_buffer_no_copy(device_.get(), runtime_mem_->ptr(),
runtime_mem_->size());
TI_DEBUG(
"Metal runtime buffer size: {} bytes (sizeof(Runtime)={} "
"memory_pool={})",
runtime_mem_->size(), compiled_structs_.runtime_size, mem_pool_bytes);
TI_ASSERT_INFO(
runtime_buffer_ != nullptr,
"Failed to allocate Metal runtime buffer, requested {} bytes",
runtime_mem_->size());
init_runtime(params.root_id);
}

void register_taichi_kernel(
Expand All @@ -447,6 +456,7 @@ class KernelManager::Impl {
params.mtl_source_code = mtl_kernel_source_code;
params.mtl_kernels_attribs = &kernels_attribs;
params.ctx_attribs = &ctx_attribs;
params.snode_descriptors = &compiled_structs_.snode_descriptors;
params.device = device_.get();
params.mem_pool = mem_pool_;
params.profiler = profiler_;
Expand Down Expand Up @@ -560,32 +570,27 @@ class KernelManager::Impl {
TI_DEBUG("Initialized SNodeExtractors, size={} accumuated={}", addr_offset,
(addr - addr_begin));
// init snode_lists
ListManager *const rtm_list_head = reinterpret_cast<ListManager *>(addr);
int list_data_mem_begin = 0;
ListManagerData *const rtm_list_head =
reinterpret_cast<ListManagerData *>(addr);
for (int i = 0; i < max_snodes; ++i) {
auto iter = snode_descriptors.find(i);
if (iter == snode_descriptors.end()) {
continue;
}
const SNodeDescriptor &sn_desc = iter->second;
ListManager *rtm_list = reinterpret_cast<ListManager *>(addr) + i;
ListManagerData *rtm_list = reinterpret_cast<ListManagerData *>(addr) + i;
rtm_list->element_stride = sizeof(ListgenElement);
// This can be really large, especially for other sparse SNodes (e.g.
// dynamic, hash). In the future, Metal might also be able to support
// dynamic memory allocation from the kernel side. That should help reduce
// the initial size.
rtm_list->max_num_elems =
sn_desc.total_num_self_from_root(snode_descriptors);

const int num_elems_per_chunk = compute_num_elems_per_chunk(
sn_desc.total_num_self_from_root(snode_descriptors));
rtm_list->log2_num_elems_per_chunk = log2int(num_elems_per_chunk);
rtm_list->next = 0;
rtm_list->mem_begin = list_data_mem_begin;
list_data_mem_begin += rtm_list->max_num_elems * rtm_list->element_stride;
TI_DEBUG("ListManager\n id={}\n num_slots={}\n mem_begin={}\n", i,
rtm_list->max_num_elems, rtm_list->mem_begin);
TI_DEBUG("ListManagerData\n id={}\n num_elems_per_chunk={}\n", i,
num_elems_per_chunk);
}
addr_offset = sizeof(ListManager) * max_snodes;
addr_offset = sizeof(ListManagerData) * max_snodes;
addr += addr_offset;
TI_DEBUG("Initialized ListManager, size={} accumuated={}", addr_offset,
TI_DEBUG("Initialized ListManagerData, size={} accumuated={}", addr_offset,
(addr - addr_begin));
// init rand_seeds
// TODO(k-ye): Provide a way to use a fixed seed in dev mode.
Expand All @@ -604,14 +609,30 @@ class KernelManager::Impl {
kNumRandSeeds * sizeof(uint32_t), (addr - addr_begin));

if (compiled_structs_.need_snode_lists_data) {
auto *alloc = reinterpret_cast<MemoryAllocator *>(addr);
// Make sure the retured memory address is always greater than 1.
alloc->next = shaders::kAlignment;
// root list data are static
ListgenElement root_elem;
root_elem.root_mem_offset = 0;
for (int i = 0; i < taichi_max_num_indices; ++i) {
root_elem.coords[i] = 0;
}
append(rtm_list_head + root_id, root_elem, addr);
ListManager root_lm;
root_lm.lm_data = rtm_list_head + root_id;
root_lm.mem_alloc = alloc;
append(&root_lm, root_elem);
}
}

static int compute_num_elems_per_chunk(int n) {
const int lb =
(n + shaders::kTaichiNumChunks - 1) / shaders::kTaichiNumChunks;
int result = 1024;
while (result < lb) {
result <<= 1;
}
return result;
}

void create_new_command_buffer() {
Expand Down
1 change: 1 addition & 0 deletions taichi/backends/metal/shaders/atomic_stubs.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ template <typename T>
bool atomic_compare_exchange_weak_explicit(T *object,
T *expected,
T desired,
metal::memory_order,
metal::memory_order) {
const T val = *object;
if (val == *expected) {
Expand Down
36 changes: 20 additions & 16 deletions taichi/backends/metal/shaders/runtime_kernels.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ STR(
if (utid_ > 0)
return;
int child_snode_id = args[1];
device ListManager *child_list =
&(reinterpret_cast<device Runtime *>(runtime_addr)
->snode_lists[child_snode_id]);
clear(child_list);
ListManager child_list;
child_list.lm_data =
(reinterpret_cast<device Runtime *>(runtime_addr)->snode_lists +
child_snode_id);
clear(&child_list);
}

kernel void element_listgen(device byte *runtime_addr[[buffer(0)]],
Expand All @@ -62,30 +63,33 @@ STR(
const uint grid_size[[threads_per_grid]]) {
device Runtime *runtime =
reinterpret_cast<device Runtime *>(runtime_addr);
device byte *list_data_addr =
reinterpret_cast<device byte *>(runtime + 1);
device MemoryAllocator *mem_alloc =
reinterpret_cast<device MemoryAllocator *>(runtime + 1);

int parent_snode_id = args[0];
int child_snode_id = args[1];
device ListManager *parent_list =
&(runtime->snode_lists[parent_snode_id]);
device ListManager *child_list = &(runtime->snode_lists[child_snode_id]);
const int parent_snode_id = args[0];
const int child_snode_id = args[1];
ListManager parent_list;
parent_list.lm_data = (runtime->snode_lists + parent_snode_id);
parent_list.mem_alloc = mem_alloc;
ListManager child_list;
child_list.lm_data = (runtime->snode_lists + child_snode_id);
child_list.mem_alloc = mem_alloc;
const SNodeMeta parent_meta = runtime->snode_metas[parent_snode_id];
const int child_stride = parent_meta.element_stride;
const int num_slots = parent_meta.num_slots;
const SNodeMeta child_meta = runtime->snode_metas[child_snode_id];
// |max_num_elems| is NOT padded to power-of-two, while |num_slots| is.
// So we need to cap the loop precisely at child's |max_num_elems|.
for (int ii = utid_; ii < child_list->max_num_elems; ii += grid_size) {
const int max_num_elems = args[2];
for (int ii = utid_; ii < max_num_elems; ii += grid_size) {
const int parent_idx = (ii / num_slots);
if (parent_idx >= num_active(parent_list)) {
if (parent_idx >= num_active(&parent_list)) {
// Since |parent_idx| increases monotonically, we can return directly
// once it goes beyond the number of active parent elements.
return;
}
const int child_idx = (ii % num_slots);
const auto parent_elem =
get<ListgenElement>(parent_list, parent_idx, list_data_addr);
const auto parent_elem = get<ListgenElement>(&parent_list, parent_idx);
ListgenElement child_elem;
child_elem.root_mem_offset = parent_elem.root_mem_offset +
child_idx * child_stride +
Expand All @@ -95,7 +99,7 @@ STR(
refine_coordinates(parent_elem,
runtime->snode_extractors[parent_snode_id],
child_idx, &child_elem);
append(child_list, child_elem, list_data_addr);
append(&child_list, child_elem);
}
}
}
Expand Down
27 changes: 17 additions & 10 deletions taichi/backends/metal/shaders/runtime_structs.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

static_assert(taichi_max_num_indices == 8,
"Please update kTaichiMaxNumIndices");

static_assert(sizeof(char *) == 8, "Metal pointers are 64-bit.");
#define METAL_BEGIN_RUNTIME_STRUCTS_DEF
#define METAL_END_RUNTIME_STRUCTS_DEF

Expand All @@ -30,23 +30,30 @@ METAL_BEGIN_RUNTIME_STRUCTS_DEF
STR(
// clang-format on
constant constexpr int kTaichiMaxNumIndices = 8;
constant constexpr int kTaichiNumChunks = 1024;

struct MemoryAllocator { atomic_int next; };

struct ListgenElement {
int32_t coords[kTaichiMaxNumIndices];
int32_t root_mem_offset = 0;
};

// ListManager manages the activeness of its associated SNode.
struct ListManager {
// ListManagerData manages the activeness of its associated SNode.
struct ListManagerData {
int32_t element_stride = 0;
// Total number of this SNode in the hierarchy.
// Same as |total_num_self_from_root| of this SNode.
int32_t max_num_elems = 0;

int32_t log2_num_elems_per_chunk = 0;
// Index to the next element in this list.
// |next| can never go beyond |max_num_elems|.
int32_t next = 0;
// The data offset from the runtime memory beginning.
int32_t mem_begin = 0;
// |next| can never go beyond |kTaichiNumChunks| * |num_elems_per_chunk|.
atomic_int next;

atomic_int chunks[kTaichiNumChunks];
};

struct ListManager {
device ListManagerData *lm_data;
device MemoryAllocator *mem_alloc;
};

// This class is very similar to metal::SNodeDescriptor
Expand Down
Loading

0 comments on commit 482e82d

Please sign in to comment.