Skip to content

Commit

Permalink
migrate all buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye committed Apr 28, 2022
1 parent afa9524 commit 5d086dc
Showing 1 changed file with 31 additions and 36 deletions.
67 changes: 31 additions & 36 deletions taichi/backends/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,34 +576,29 @@ class KernelManager::Impl {

const size_t mem_pool_bytes =
(config_->device_memory_GB * 1024 * 1024 * 1024ULL);
runtime_mem_ = std::make_unique<BufferMemoryView>(
compiled_runtime_module_.runtime_size + mem_pool_bytes, mem_pool_);
runtime_buffer_ = new_mtl_buffer_no_copy(device_.get(), runtime_mem_->ptr(),
runtime_mem_->size());
runtime_idevalloc_ =
make_idevalloc(compiled_runtime_module_.runtime_size + mem_pool_bytes);
buffer_meta_data_.runtime_buffer_size =
compiled_runtime_module_.runtime_size;
TI_DEBUG(
"Metal runtime buffer size: {} bytes (sizeof(Runtime)={} "
"memory_pool={})",
runtime_mem_->size(), compiled_runtime_module_.runtime_size,
runtime_idevalloc_.mem->size(), compiled_runtime_module_.runtime_size,
mem_pool_bytes);

ActionRecorder::get_instance().record(
"allocate_runtime_buffer",
{ActionArg("runtime_buffer_size_in_bytes", (int64)runtime_mem_->size()),
{ActionArg("runtime_buffer_size_in_bytes",
(int64)runtime_idevalloc_.mem->size()),
ActionArg("runtime_size_in_bytes",
(int64)compiled_runtime_module_.runtime_size),
ActionArg("memory_pool_size", (int64)mem_pool_bytes)});

TI_ASSERT_INFO(
runtime_buffer_ != nullptr,
"Failed to allocate Metal runtime buffer, requested {} bytes",
runtime_mem_->size());
print_mem_ = std::make_unique<BufferMemoryView>(
shaders::kMetalPrintAssertBufferSize, mem_pool_);
print_buffer_ = new_mtl_buffer_no_copy(device_.get(), print_mem_->ptr(),
print_mem_->size());
TI_ASSERT(print_buffer_ != nullptr);
TI_ASSERT_INFO(runtime_idevalloc_.buffer != nullptr,
"Failed to allocate Metal runtime buffer");
print_assert_idevalloc_ =
make_idevalloc(shaders::kMetalPrintAssertBufferSize);
TI_ASSERT(print_assert_idevalloc_.buffer != nullptr);

init_runtime_buffer(compiled_runtime_module_, params.config->random_seed);
clear_print_assert_buffer();
Expand All @@ -614,6 +609,8 @@ class KernelManager::Impl {
rhi_device_->dealloc_memory(rb.idevalloc.handle);
}
rhi_device_->dealloc_memory(global_tmps_idevalloc_.handle);
rhi_device_->dealloc_memory(runtime_idevalloc_.handle);
rhi_device_->dealloc_memory(print_assert_idevalloc_.handle);
}

void add_compiled_snode_tree(const CompiledStructs &compiled_tree) {
Expand Down Expand Up @@ -679,8 +676,8 @@ class KernelManager::Impl {
}
input_buffers[BufferDescriptor::global_tmps()] =
global_tmps_idevalloc_.buffer;
input_buffers[BufferDescriptor::runtime()] = runtime_buffer_.get();
input_buffers[BufferDescriptor::print()] = print_buffer_.get();
input_buffers[BufferDescriptor::runtime()] = runtime_idevalloc_.buffer;
input_buffers[BufferDescriptor::print()] = print_assert_idevalloc_.buffer;

if (ctx_blitter) {
ctx_blitter->host_to_metal();
Expand All @@ -703,7 +700,7 @@ class KernelManager::Impl {
}
if (used_print_assert) {
clear_print_assert_buffer();
buffers_to_blit.push_back(print_buffer_.get());
buffers_to_blit.push_back(print_assert_idevalloc_.buffer);
}
blit_buffers_and_sync(buffers_to_blit);

Expand Down Expand Up @@ -735,7 +732,7 @@ class KernelManager::Impl {
std::size_t get_snode_num_dynamically_allocated(SNode *snode) {
// TODO(k-ye): Have a generic way for querying these sparse runtime stats.
mac::ScopedAutoreleasePool pool;
blit_buffers_and_sync({runtime_buffer_.get()});
blit_buffers_and_sync({runtime_idevalloc_.buffer});
auto *sna = dev_runtime_mirror_.snode_allocators + snode->id;
// WHY -1?
//
Expand All @@ -752,7 +749,7 @@ class KernelManager::Impl {
private:
void init_runtime_buffer(const CompiledRuntimeModule &rtm_module,
int random_seed) {
char *addr = runtime_mem_->ptr();
char *addr = runtime_idevalloc_.mem->ptr();
// init rand_seeds
std::default_random_engine generator((unsigned int)random_seed);
std::uniform_int_distribution<uint32_t> distr(
Expand All @@ -765,7 +762,7 @@ class KernelManager::Impl {
TI_DEBUG("Initialized random seeds size={}", rtm_module.rand_seeds_size);

using namespace shaders;
addr = runtime_mem_->ptr() + rtm_module.rand_seeds_size;
addr = runtime_idevalloc_.mem->ptr() + rtm_module.rand_seeds_size;
const char *const addr_begin = addr;
dev_runtime_mirror_.snode_metas = (SNodeMeta *)addr;
size_t addr_offset = sizeof(SNodeMeta) * kMaxNumSNodes;
Expand Down Expand Up @@ -963,14 +960,14 @@ class KernelManager::Impl {
}

void mark_runtime_buffer_modified() {
did_modify_range(runtime_buffer_.get(), /*location=*/0,
runtime_mem_->size());
did_modify_range(runtime_idevalloc_.buffer, /*location=*/0,
runtime_idevalloc_.mem->size());
}

void clear_print_assert_buffer() {
const auto sz = print_mem_->size();
std::memset(print_mem_->ptr(), 0, sz);
did_modify_range(print_buffer_.get(), /*location=*/0, sz);
const auto sz = print_assert_idevalloc_.mem->size();
std::memset(print_assert_idevalloc_.mem->ptr(), 0, sz);
did_modify_range(print_assert_idevalloc_.buffer, /*location=*/0, sz);
}

void blit_buffers_and_sync(
Expand Down Expand Up @@ -1002,8 +999,8 @@ class KernelManager::Impl {
void check_assertion_failure() {
// TODO: Copy this to program's result_buffer, and let the Taichi runtime
// handle the assertion failures uniformly.
auto *asst_rec =
reinterpret_cast<shaders::AssertRecorderData *>(print_mem_->ptr());
auto *asst_rec = reinterpret_cast<shaders::AssertRecorderData *>(
print_assert_idevalloc_.mem->ptr());
if (!asst_rec->flag) {
return;
}
Expand All @@ -1030,8 +1027,9 @@ class KernelManager::Impl {
//
// check_assertion_failure(); <-- Code below is skipped...
// flush_print_buffers();
// memset(print_mem_->ptr(), 0, print_mem_->size());
// did_modify_range(print_buffer_);
// memset(print_assert_idevalloc_.mem->ptr(), 0,
// print_assert_idevalloc_.mem->size());
// did_modify_range(print_assert_idevalloc_.buffer);
//
// As a workaround, we put [didModifyRange:] before sync, where the program
// is still executing normally.
Expand All @@ -1041,7 +1039,7 @@ class KernelManager::Impl {

void flush_print_buffers() {
auto *pa = reinterpret_cast<shaders::PrintMsgAllocator *>(
print_mem_->ptr() + shaders::kMetalAssertBufferSize);
print_assert_idevalloc_.mem->ptr() + shaders::kMetalAssertBufferSize);
const int used_sz =
std::min(pa->next, shaders::kMetalPrintMsgsMaxQueueSize);
TI_TRACE("Print buffer used bytes: {}", used_sz);
Expand Down Expand Up @@ -1142,12 +1140,9 @@ class KernelManager::Impl {
std::size_t command_buffer_id_{0};
std::vector<SNodesRootBuffer> root_buffers_;
DevAllocWithInternals global_tmps_idevalloc_;
std::unique_ptr<BufferMemoryView> runtime_mem_{nullptr};
nsobj_unique_ptr<MTLBuffer> runtime_buffer_{nullptr};
DevAllocWithInternals runtime_idevalloc_;
DevAllocWithInternals print_assert_idevalloc_;
int last_snode_id_used_in_runtime_{-1};
// TODO: Rename these to 'print_assert_{mem|buffer}_'
std::unique_ptr<BufferMemoryView> print_mem_{nullptr};
nsobj_unique_ptr<MTLBuffer> print_buffer_{nullptr};
std::unordered_map<std::string, std::unique_ptr<CompiledTaichiKernel>>
compiled_taichi_kernels_;
PrintStringTable print_strtable_;
Expand Down

0 comments on commit 5d086dc

Please sign in to comment.