From 6e68872c69ce1a4619a8648d93d74d312064ee25 Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Thu, 22 Oct 2020 04:54:27 +0900 Subject: [PATCH] [metal] Create helper methods for TLS codegen (#1982) --- taichi/backends/metal/codegen_metal.cpp | 59 +++++++++++-------------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index 20436f868cd5d..3836aa0132b3a 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -931,17 +931,9 @@ class KernelCodegen : public IRVisitor { // end_ = total_elems + begin_expr emit("const int end_ = {} + {};", total_elems_name, begin_expr); + emit_runtime_and_memalloc_def(); if (used_tls) { - // Using TLS means we will access some SNodes within this kernel. The - // struct of an SNode needs Runtime and MemoryAllocator to construct. - emit_runtime_and_memalloc_def(); - // Using |int32_t| because it aligns to 4bytes. - emit("// TLS prologue"); - const std::string tls_bufi32_name = "tls_bufi32_"; - emit("int32_t {}[{}];", tls_bufi32_name, (stmt->tls_size + 3) / 4); - emit("thread char* {} = reinterpret_cast({});", - kTlsBufferName, tls_bufi32_name); - stmt->tls_prologue->accept(this); + generate_tls_prologue(stmt); } emit("for (int ii = begin_; ii < end_; ii += {}) {{", kKernelGridSizeName); @@ -964,12 +956,7 @@ class KernelCodegen : public IRVisitor { emit("}}"); // closes for loop if (used_tls) { - TI_ASSERT(stmt->tls_epilogue != nullptr); - inside_tls_epilogue_ = true; - emit("{{ // TLS epilogue"); - stmt->tls_epilogue->accept(this); - inside_tls_epilogue_ = false; - emit("}}"); + generate_tls_epilogue(stmt); } current_appender().pop_indent(); @@ -1010,17 +997,7 @@ class KernelCodegen : public IRVisitor { emit_runtime_and_memalloc_def(); if (used_tls) { - // Using TLS means we will access some SNodes within this kernel. The - // struct of an SNode needs Runtime and MemoryAllocator to construct. - // Using |int32_t| because it aligns to 4bytes. - // - // TODO(k-ye): De-dupe TLS for range-for and struct-for. - emit("// TLS prologue"); - const std::string tls_bufi32_name = "tls_bufi32_"; - emit("int32_t {}[{}];", tls_bufi32_name, (stmt->tls_size + 3) / 4); - emit("thread char* {} = reinterpret_cast({});", - kTlsBufferName, tls_bufi32_name); - stmt->tls_prologue->accept(this); + generate_tls_prologue(stmt); } emit("ListManager parent_list;"); @@ -1079,19 +1056,33 @@ class KernelCodegen : public IRVisitor { current_appender().pop_indent(); if (used_tls) { - // TODO(k-ye): De-dupe TLS for range-for and struct-for. - TI_ASSERT(stmt->tls_epilogue != nullptr); - inside_tls_epilogue_ = true; - emit("{{ // TLS epilogue"); - stmt->tls_epilogue->accept(this); - inside_tls_epilogue_ = false; - emit("}}"); + generate_tls_epilogue(stmt); } emit("}}\n"); // closes kernel mtl_kernels_attribs()->push_back(ka); } + void generate_tls_prologue(OffloadedStmt *stmt) { + TI_ASSERT(stmt->tls_prologue != nullptr); + emit("// TLS prologue"); + const std::string tls_bufi32_name = "tls_bufi32_"; + // Using |int32_t| because it aligns to 4bytes. + emit("int32_t {}[{}];", tls_bufi32_name, (stmt->tls_size + 3) / 4); + emit("thread char* {} = reinterpret_cast({});", + kTlsBufferName, tls_bufi32_name); + stmt->tls_prologue->accept(this); + } + + void generate_tls_epilogue(OffloadedStmt *stmt) { + TI_ASSERT(stmt->tls_epilogue != nullptr); + inside_tls_epilogue_ = true; + emit("{{ // TLS epilogue"); + stmt->tls_epilogue->accept(this); + inside_tls_epilogue_ = false; + emit("}}"); + } + void add_runtime_list_op_kernel(OffloadedStmt *stmt, const std::string &kernel_name) { using Type = OffloadedStmt::TaskType;