diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index 506dd8a73c9fb..733b61d3ca337 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -69,6 +69,14 @@ std::string buffer_to_name(BuffersEnum b) { return {}; } +bool is_ret_type_bit_pointer(Stmt *s) { + if (auto *ty = s->ret_type->cast()) { + // Don't use as() directly, it would fail when we inject a global tmp. + return ty->is_bit_pointer(); + } + return false; +} + class KernelCodegen : public IRVisitor { private: enum class Section { @@ -128,10 +136,12 @@ class KernelCodegen : public IRVisitor { generate_kernels(); std::string source_code; - for (const auto s : kAllSections) { - source_code += section_appenders_.find(s)->second.lines(); - source_code += '\n'; - } + source_code += section_appenders_.at(Section::Headers).lines(); + source_code += "namespace {\n"; + source_code += section_appenders_.at(Section::Structs).lines(); + source_code += section_appenders_.at(Section::KernelFuncs).lines(); + source_code += "} // namespace\n"; + source_code += section_appenders_.at(Section::Kernels).lines(); return source_code; } @@ -189,21 +199,6 @@ class KernelCodegen : public IRVisitor { kRootBufferName); } - void visit(GetChStmt *stmt) override { - // E.g. `parent.get*(runtime, mem_alloc)` - const auto get_call = - fmt::format("{}.get{}({}, {})", stmt->input_ptr->raw_name(), stmt->chid, - kRuntimeVarName, kMemAllocVarName); - if (stmt->output_snode->is_place()) { - emit(R"(device {}* {} = {}.val;)", - metal_data_type_name(stmt->output_snode->dt), stmt->raw_name(), - get_call); - } else { - emit(R"({} {} = {};)", stmt->output_snode->node_type_name, - stmt->raw_name(), get_call); - } - } - void visit(LinearizeStmt *stmt) override { std::string val = "0"; for (int i = 0; i < (int)stmt->inputs.size(); i++) { @@ -229,8 +224,29 @@ class KernelCodegen : public IRVisitor { } const auto *sn = stmt->snode; const auto snty = sn->type; + if (snty == SNodeType::bit_struct) { + // Example *bit_struct* struct generated on Metal: + // + // struct Sx { + // // bit_struct + // Sx(device byte *b, ...) : base(b) {} + // device byte *base; + // }; + emit("auto {} = {}.base;", stmt->raw_name(), parent); + return; + } const std::string index_name = stmt->input_index->raw_name(); - + // Example SNode struct generated on Metal: + // + // struct S1 { + // // dense + // S1(device byte *addr, ...) { rep_.init(addr); } + // S1_ch children(int i) { return {rep_.addr() + (i * elem_stride)}; } + // inline void activate(int i) { rep_.activate(i); } + // ... + // private: + // SNodeRep_dense rep_; + // }; if (stmt->activate) { TI_ASSERT(is_supported_sparse_type(snty)); emit("{}.activate({});", parent, index_name); @@ -239,6 +255,32 @@ class KernelCodegen : public IRVisitor { parent, index_name); } + void visit(GetChStmt *stmt) override { + auto *in_snode = stmt->input_snode; + auto *out_snode = stmt->output_snode; + if (in_snode->type == SNodeType::bit_struct) { + TI_ASSERT(stmt->ret_type->as()->is_bit_pointer()); + const auto *bit_struct_ty = in_snode->dt->cast(); + const auto bit_offset = + bit_struct_ty->get_member_bit_offset(in_snode->child_id(out_snode)); + // stmt->input_ptr is the "base" member in the generated SNode struct. + emit("SNodeBitPointer {}({}, /*offset=*/{});", stmt->raw_name(), + stmt->input_ptr->raw_name(), bit_offset); + return; + } + // E.g. 
`parent.get*(runtime, mem_alloc)` + const auto get_call = + fmt::format("{}.get{}({}, {})", stmt->input_ptr->raw_name(), stmt->chid, + kRuntimeVarName, kMemAllocVarName); + if (out_snode->is_place()) { + emit(R"(device {}* {} = {}.val;)", metal_data_type_name(out_snode->dt), + stmt->raw_name(), get_call); + } else { + emit(R"({} {} = {};)", out_snode->node_type_name, stmt->raw_name(), + get_call); + } + } + void visit(SNodeOpStmt *stmt) override { const std::string result_var = stmt->raw_name(); const auto opty = stmt->op_type; @@ -292,13 +334,23 @@ class KernelCodegen : public IRVisitor { void visit(GlobalStoreStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - emit(R"(*{} = {};)", stmt->ptr->raw_name(), stmt->data->raw_name()); + + if (!is_ret_type_bit_pointer(stmt->ptr)) { + emit(R"(*{} = {};)", stmt->ptr->raw_name(), stmt->data->raw_name()); + return; + } + handle_bit_pointer_global_store(stmt); } void visit(GlobalLoadStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - emit(R"({} {} = *{};)", metal_data_type_name(stmt->element_type()), - stmt->raw_name(), stmt->ptr->raw_name()); + std::string rhs_expr; + if (!is_ret_type_bit_pointer(stmt->ptr)) { + rhs_expr = fmt::format("*{}", stmt->ptr->raw_name()); + } else { + rhs_expr = construct_bit_pointer_global_load(stmt); + } + emit("const auto {} = {};", stmt->raw_name(), rhs_expr); } void visit(ArgLoadStmt *stmt) override { @@ -457,7 +509,6 @@ class KernelCodegen : public IRVisitor { void visit(AtomicOpStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - const auto dt = stmt->val->element_type(); const auto op_type = stmt->op_type; std::string op_name; bool handle_float = false; @@ -475,6 +526,11 @@ class KernelCodegen : public IRVisitor { TI_NOT_IMPLEMENTED; } + if (is_ret_type_bit_pointer(stmt->dest)) { + handle_bit_pointer_atomics(stmt); + return; + } + std::string val_var = stmt->val->raw_name(); // TODO(k-ye): This is not a very reliable way to detect if we're in TLS // xlogues... 
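For orientation, here is a sketch of the Metal source these bit-pointer paths are expected to emit for a 13-bit custom int with scale 0.1 stored inside a bit_struct; the variable names and constants are illustrative, not actual codegen output:

      // GetChStmt on a bit_struct parent: wrap the base address plus a bit offset
      SNodeBitPointer tmp4(tmp3.base, /*offset=*/0);
      // GlobalStoreStmt of the custom float: scale + round, then CAS-write the bits
      mtl_set_partial_bits(tmp4,
          mtl_float_to_custom_int<int32_t>(/*inv_scale=*/10.0f * tmp2),
          /*bits=*/13);
      // GlobalLoadStmt: extract the bits, then rescale
      const auto tmp5 = (static_cast<float>(mtl_get_partial_bits<int32_t>(tmp4, 13)) * 0.1);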
@@ -488,7 +544,7 @@ class KernelCodegen : public IRVisitor { emit("if ({} == 0) {{", kKernelTidInSimdgroupName); current_appender().push_indent(); } - + const auto dt = stmt->val->element_type(); if (dt->is_primitive(PrimitiveTypeID::i32)) { emit( "const auto {} = atomic_fetch_{}_explicit((device atomic_int*){}, " @@ -626,9 +682,11 @@ class KernelCodegen : public IRVisitor { if (std::holds_alternative(entry)) { auto *arg_stmt = std::get(entry); const auto dt = arg_stmt->element_type(); - TI_ASSERT_INFO(dt->is_primitive(PrimitiveTypeID::i32) || - dt->is_primitive(PrimitiveTypeID::f32), - "print() only supports i32 or f32 scalars for now."); + TI_ASSERT_INFO( + dt->is_primitive(PrimitiveTypeID::i32) || + dt->is_primitive(PrimitiveTypeID::u32) || + dt->is_primitive(PrimitiveTypeID::f32), + "print() only supports i32, u32 or f32 scalars for now."); emit("{}.pm_set_{}({}, {});", msg_var_name, data_type_name(dt), i, arg_stmt->raw_name()); } else { @@ -773,6 +831,133 @@ class KernelCodegen : public IRVisitor { emit_kernel_args_struct(); } + void handle_bit_pointer_global_store(GlobalStoreStmt *stmt) { + auto *ptr_type = stmt->ptr->ret_type->as(); + TI_ASSERT(ptr_type->is_bit_pointer()); + auto *pointee_type = ptr_type->get_pointee_type(); + CustomIntType *cit = nullptr; + std::string store_value_expr; + if (auto *cit_cast = pointee_type->cast()) { + cit = cit_cast; + store_value_expr = stmt->data->raw_name(); + } else if (auto *cft = pointee_type->cast()) { + validate_cft_for_metal(cft); + auto *digits_cit = cft->get_digits_type()->as(); + cit = digits_cit; + store_value_expr = construct_float_to_custom_int_expr( + stmt->data, cft->get_scale(), digits_cit); + } else { + TI_NOT_IMPLEMENTED; + } + // Type of |stmt->ptr| is SNodeBitPointer + const auto num_bits = cit->get_num_bits(); + if (is_full_bits(num_bits)) { + emit("mtl_set_full_bits({}, {});", stmt->ptr->raw_name(), + store_value_expr); + } else { + emit("mtl_set_partial_bits({},", stmt->ptr->raw_name()); + emit(" {},", store_value_expr); + emit(" /*bits=*/{});", num_bits); + } + } + + // Returns the expression of the load result + std::string construct_bit_pointer_global_load(GlobalLoadStmt *stmt) const { + auto *ptr_type = stmt->ptr->ret_type->as(); + TI_ASSERT(ptr_type->is_bit_pointer()); + auto *pointee_type = ptr_type->get_pointee_type(); + if (auto *cit = pointee_type->cast()) { + return construct_load_as_custom_int(stmt->ptr, cit); + } else if (auto *cft = pointee_type->cast()) { + validate_cft_for_metal(cft); + const auto loaded = construct_load_as_custom_int( + stmt->ptr, cft->get_digits_type()->as()); + // Computes `float(digits_expr) * scale` + // See LLVM backend's reconstruct_custom_float() + return fmt::format("(static_cast({}) * {})", loaded, + cft->get_scale()); + } + TI_NOT_IMPLEMENTED; + return ""; + } + + void handle_bit_pointer_atomics(AtomicOpStmt *stmt) { + TI_ERROR_IF(stmt->op_type != AtomicOpType::add, + "Only atomic add is supported for bit pointer types"); + // Type of |dest_ptr| is SNodeBitPointer + const auto *dest_ptr = stmt->dest; + auto *ptr_type = dest_ptr->ret_type->as(); + TI_ASSERT(ptr_type->is_bit_pointer()); + auto *pointee_type = ptr_type->get_pointee_type(); + CustomIntType *cit = nullptr; + std::string val_expr; + if (auto *cit_cast = pointee_type->cast()) { + cit = cit_cast; + val_expr = stmt->val->raw_name(); + } else if (auto *cft = pointee_type->cast()) { + cit = cft->get_digits_type()->as(); + val_expr = + construct_float_to_custom_int_expr(stmt->val, cft->get_scale(), cit); + } else { + 
TI_NOT_IMPLEMENTED; + } + const auto num_bits = cit->get_num_bits(); + if (is_full_bits(num_bits)) { + emit("const auto {} = mtl_atomic_add_full_bits({}, {});", + stmt->raw_name(), dest_ptr->raw_name(), val_expr); + } else { + emit("const auto {} = mtl_atomic_add_partial_bits({},", stmt->raw_name(), + dest_ptr->raw_name()); + emit(" {},", val_expr); + emit(" /*bits=*/{});", num_bits); + } + } + + // Returns the expression of `int(val_stmt * (1.0f / scale) + 0.5f)` + std::string construct_float_to_custom_int_expr( + const Stmt *val_stmt, + float64 scale, + CustomIntType *digits_cit) const { + DataType compute_dt(digits_cit->get_compute_type()->as()); + // This implicitly casts double to float on the host. + const float inv_scale = 1.0 / scale; + // Creating an expression (instead of holding intermediate results with + // variables) because |val_stmt| could be used multiple times. If the + // intermediate variables are named based on |val_stmt|, it would result in + // symbol redefinitions. + return fmt::format("mtl_float_to_custom_int<{}>(/*inv_scale=*/{} * {})", + metal_data_type_name(compute_dt), inv_scale, + val_stmt->raw_name()); + } + + // Returns expression of the loaded integer. + std::string construct_load_as_custom_int(const Stmt *bit_ptr_stmt, + CustomIntType *cit) const { + DataType compute_dt(cit->get_compute_type()->as()); + const auto num_bits = cit->get_num_bits(); + if (is_full_bits(num_bits)) { + return fmt::format("mtl_get_full_bits<{}>({})", + metal_data_type_name(compute_dt), + bit_ptr_stmt->raw_name()); + } + return fmt::format("mtl_get_partial_bits<{}>({}, {})", + metal_data_type_name(compute_dt), + bit_ptr_stmt->raw_name(), num_bits); + } + + void validate_cft_for_metal(CustomFloatType *cft) const { + if (cft->get_exponent_type() != nullptr) { + TI_NOT_IMPLEMENTED; + } + if (cft->get_compute_type()->as() != PrimitiveType::f32) { + TI_ERROR("Metal only supports 32-bit float"); + } + } + + static bool is_full_bits(int bits) { + return bits == (sizeof(uint32_t) * 8); + } + void emit_kernel_args_struct() { if (ctx_attribs_.empty()) { return; @@ -924,6 +1109,7 @@ class KernelCodegen : public IRVisitor { emit("const int {} = {} - {};", total_elems_name, end_expr, begin_expr); ka.advisory_total_num_threads = kMaxNumThreadsGridStrideLoop; } + // TODO: I've seen cases where |block_dim| was set to 1... 
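+      // The generated kernel walks [begin, end) with a grid-stride loop,
+      // roughly `for (int ii = begin_; ii < end_; ii += total_threads)`
+      // (sketch, not verbatim output), which is why advisory_total_num_threads
+      // can be capped at kMaxNumThreadsGridStrideLoop even for larger ranges.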
ka.advisory_num_threads_per_group = stmt->block_dim; // begin_ = thread_id + begin_expr emit("const int begin_ = {} + {};", kKernelThreadIdName, begin_expr); diff --git a/taichi/backends/metal/kernel_manager.cpp b/taichi/backends/metal/kernel_manager.cpp index 050ea9bc20f33..56dff3d562379 100644 --- a/taichi/backends/metal/kernel_manager.cpp +++ b/taichi/backends/metal/kernel_manager.cpp @@ -10,14 +10,15 @@ #include "taichi/backends/metal/constants.h" #include "taichi/inc/constants.h" #include "taichi/math/arithmetic.h" -#include "taichi/util/action_recorder.h" #include "taichi/python/print_buffer.h" +#include "taichi/util/action_recorder.h" #include "taichi/util/file_sequence_writer.h" #include "taichi/util/str.h" #ifdef TI_PLATFORM_OSX #include #include + #include #include "taichi/backends/metal/api.h" @@ -137,13 +138,27 @@ class CompiledMtlKernelBase { const auto tgs = get_thread_grid_settings( num_threads, kernel_attribs_.advisory_num_threads_per_group); if (!is_jit_evalutor_) { - ActionRecorder::get_instance().record( - "launch_kernel", - {ActionArg("kernel_name", kernel_attribs_.name), - ActionArg("num_threadgroups", tgs.num_threadgroups), - ActionArg("num_threads_per_group", tgs.num_threads_per_group)}); + const auto tt = kernel_attribs_.task_type; + std::vector record_args = { + ActionArg("mtl_kernel_name", kernel_attribs_.name), + ActionArg("advisory_num_threads", num_threads), + ActionArg("num_threadgroups", tgs.num_threadgroups), + ActionArg("num_threads_per_group", tgs.num_threads_per_group), + ActionArg("task_type", offloaded_task_type_name(tt)), + }; + const auto &buffers = kernel_attribs_.buffers; + for (int i = 0; i < buffers.size(); ++i) { + record_args.push_back( + ActionArg(fmt::format("mtl_buffer_{}", i), + KernelAttributes::buffers_name(buffers[i]))); + } + ActionRecorder::get_instance().record("launch_kernel", + std::move(record_args)); } - + TI_TRACE( + "Dispatching Metal kernel {}, num_threadgroups={} " + "num_threads_per_group={}", + kernel_attribs_.name, tgs.num_threadgroups, tgs.num_threads_per_group); dispatch_threadgroups(encoder.get(), tgs.num_threadgroups, tgs.num_threads_per_group); end_encoding(encoder.get()); @@ -302,7 +317,7 @@ class CompiledTaichiKernel { auto fn = writer.write(params.mtl_source_code); ActionRecorder::get_instance().record( "save_kernel", - {ActionArg("kernel_name", std::string(ti_kernel_attribs.name)), + {ActionArg("ti_kernel_name", std::string(ti_kernel_attribs.name)), ActionArg("filename", fn)}); } for (const auto &ka : ti_kernel_attribs.mtl_kernels_attribs) { @@ -342,7 +357,7 @@ class CompiledTaichiKernel { if (!ti_kernel_attribs.is_jit_evaluator) { ActionRecorder::get_instance().record( "allocate_context_buffer", - {ActionArg("kernel_name", std::string(ti_kernel_attribs.name)), + {ActionArg("ti_kernel_name", std::string(ti_kernel_attribs.name)), ActionArg("size_in_bytes", (int64)ctx_attribs.total_bytes())}); } ctx_buffer = @@ -394,7 +409,7 @@ class HostMetalCtxBlitter { if (!ti_kernel_attribs_->is_jit_evaluator) { ActionRecorder::get_instance().record( "context_host_to_metal", - {ActionArg("kernel_name", kernel_name_), ActionArg("arg_id", i), + {ActionArg("ti_kernel_name", kernel_name_), ActionArg("arg_id", i), ActionArg("offset_in_bytes", (int64)arg.offset_in_mem)}); } if (arg.is_array) { @@ -447,8 +462,9 @@ class HostMetalCtxBlitter { ActionRecorder::get_instance().record( "context_metal_to_host", { - ActionArg("kernel_name", kernel_name_), + ActionArg("ti_kernel_name", kernel_name_), ActionArg("arg_id", i), + 
ActionArg("arg_type", "ptr"), ActionArg("size_in_bytes", (int64)arg.stride), ActionArg("host_address", fmt::format("0x{:x}", (uint64)host_ptr)), @@ -713,6 +729,9 @@ class KernelManager::Impl { case SNodeType::pointer: rtm_meta->type = SNodeMeta::Pointer; break; + case SNodeType::bit_struct: + rtm_meta->type = SNodeMeta::BitStruct; + break; default: TI_ERROR("Unsupported SNode type={}", snode_type_name(sn_meta.snode->type)); @@ -720,9 +739,9 @@ class KernelManager::Impl { } TI_DEBUG( "SnodeMeta\n id={}\n type={}\n element_stride={}\n " - "num_slots={}\n", + "num_slots={}\n mem_offset_in_parent={}\n", i, snode_type_name(sn_meta.snode->type), rtm_meta->element_stride, - rtm_meta->num_slots); + rtm_meta->num_slots, rtm_meta->mem_offset_in_parent); } size_t addr_offset = sizeof(SNodeMeta) * max_snodes; addr += addr_offset; @@ -786,6 +805,7 @@ class KernelManager::Impl { std::chrono::duration_cast( std::chrono::system_clock::now().time_since_epoch()) .count()); + const auto rand_seeds_begin = (addr - addr_begin); std::uniform_int_distribution distr( 0, std::numeric_limits::max()); for (int i = 0; i < kNumRandSeeds; ++i) { @@ -793,9 +813,14 @@ class KernelManager::Impl { *s = distr(generator); addr += sizeof(uint32_t); } - TI_DEBUG("Initialized random seeds, size={} accumuated={}", - kNumRandSeeds * sizeof(uint32_t), (addr - addr_begin)); - + TI_DEBUG("Initialized random seeds, begin={} size={} accumuated={}", + rand_seeds_begin, kNumRandSeeds * sizeof(uint32_t), + (addr - addr_begin)); + ActionRecorder::get_instance().record( + "initialize_runtime_buffer", + { + ActionArg("rand_seeds_begin", (int64)rand_seeds_begin), + }); if (compiled_structs_.need_snode_lists_data) { auto *mem_alloc = reinterpret_cast(addr); // Make sure the retured memory address is always greater than 1. @@ -904,6 +929,8 @@ class KernelManager::Impl { const int32_t x = msg.pm_get_data(i); if (dt == MsgType::I32) { py_cout << x; + } else if (dt == MsgType::U32) { + py_cout << static_cast(x); } else if (dt == MsgType::F32) { py_cout << *reinterpret_cast(&x); } else if (dt == MsgType::Str) { diff --git a/taichi/backends/metal/shaders/atomic_stubs.h b/taichi/backends/metal/shaders/atomic_stubs.h index 77bbe6ffc9a84..8b6851929c94e 100644 --- a/taichi/backends/metal/shaders/atomic_stubs.h +++ b/taichi/backends/metal/shaders/atomic_stubs.h @@ -3,6 +3,9 @@ using atomic_int = int; using atomic_uint = unsigned int; +template +struct _atomic {}; + namespace metal { using memory_order = bool; diff --git a/taichi/backends/metal/shaders/print.metal.h b/taichi/backends/metal/shaders/print.metal.h index 086494c1ba282..48edeabcb4c55 100644 --- a/taichi/backends/metal/shaders/print.metal.h +++ b/taichi/backends/metal/shaders/print.metal.h @@ -60,7 +60,7 @@ STR( // * Followed by N i32s, one for each print arg. F32 are encoded to I32. // For strings, there is a string table on the host side, so that the // kernel only needs to store a I32 string ID. 
-  enum Type { I32 = 1, F32 = 2, Str = 3 };
+  enum Type { I32 = 1, U32 = 2, F32 = 3, Str = 4 };
 
   PrintMsg(device int32_t *buf, int num_entries)
       : mask_buf_(buf),
@@ -71,6 +71,12 @@ STR(
         set_entry(i, x, Type::I32);
       }
 
+      void pm_set_u32(int i, uint x) {
+        // https://stackoverflow.com/a/21769421/12003165
+        const int32_t ix = static_cast<int32_t>(x);
+        set_entry(i, ix, Type::U32);
+      }
+
       void pm_set_f32(int i, float x) {
         const int32_t ix = *reinterpret_cast<thread int32_t *>(&x);
         set_entry(i, ix, Type::F32);
diff --git a/taichi/backends/metal/shaders/prolog.h b/taichi/backends/metal/shaders/prolog.h
index 56b5c5c39b95c..2db8309aa3014 100644
--- a/taichi/backends/metal/shaders/prolog.h
+++ b/taichi/backends/metal/shaders/prolog.h
@@ -22,7 +22,7 @@
 #define thread
 #define kernel
 
-#define byte char
+using byte = char;
 
 #include "taichi/backends/metal/shaders/atomic_stubs.h"
 
diff --git a/taichi/backends/metal/shaders/runtime_structs.metal.h b/taichi/backends/metal/shaders/runtime_structs.metal.h
index bd4d555bdbdb9..171e8463f4bf3 100644
--- a/taichi/backends/metal/shaders/runtime_structs.metal.h
+++ b/taichi/backends/metal/shaders/runtime_structs.metal.h
@@ -28,7 +28,6 @@ static_assert(sizeof(char *) == 8, "Metal pointers are 64-bit.");
 // clang-format off
 METAL_BEGIN_RUNTIME_STRUCTS_DEF
 STR(
-    // clang-format on
     constant constexpr int kTaichiMaxNumIndices = 8;
     constant constexpr int kTaichiNumChunks = 1024;
     constant constexpr int kAlignment = 8;
@@ -105,6 +104,7 @@ STR(
       Bitmasked = 2,
       Dynamic = 3,
       Pointer = 4,
+      BitStruct = 5,
     };
     int32_t element_stride = 0;
     int32_t num_slots = 0;
@@ -149,7 +149,6 @@ STR(
         return belonged_nodemgr.id < 0;
       }
     };
-    // clang-format off
 )
 METAL_END_RUNTIME_STRUCTS_DEF
 // clang-format on
diff --git a/taichi/backends/metal/shaders/snode_bit_pointer.metal.h b/taichi/backends/metal/shaders/snode_bit_pointer.metal.h
new file mode 100644
index 0000000000000..5e29603b39b9a
--- /dev/null
+++ b/taichi/backends/metal/shaders/snode_bit_pointer.metal.h
@@ -0,0 +1,145 @@
+#include "taichi/backends/metal/shaders/prolog.h"
+
+#ifdef TI_INSIDE_METAL_CODEGEN
+
+#ifndef TI_METAL_NESTED_INCLUDE
+#define METAL_BEGIN_SRC_DEF constexpr auto kMetalSNodeBitPointerSourceCode =
+#define METAL_END_SRC_DEF ;
+#else
+#define METAL_BEGIN_SRC_DEF
+#define METAL_END_SRC_DEF
+#endif  // TI_METAL_NESTED_INCLUDE
+
+#else
+
+#define METAL_BEGIN_SRC_DEF
+#define METAL_END_SRC_DEF
+
+#include <type_traits>
+
+using std::is_same;
+using std::is_signed;
+
+#endif  // TI_INSIDE_METAL_CODEGEN
+
+METAL_BEGIN_SRC_DEF
+STR(
+    // SNodeBitPointer is used as the value type for bit_struct SNodes on
+    // Metal.
+    struct SNodeBitPointer {
+      // Physical type is hardcoded to uint32_t. This is a restriction because
+      // Metal only supports 32-bit int/uint atomics.
+      device uint32_t *base;
+      uint32_t offset;
+
+      SNodeBitPointer(device byte * b, uint32_t o)
+          : base((device uint32_t *)b), offset(o) {
+      }
+    };
+
+    // |f| should already be scaled. |C| is the compute type.
+    template <typename C>
+    C mtl_float_to_custom_int(float f) {
+      // Branch free implementation of `f + sign(f) * 0.5`.
+      // See rounding_prepare_f* in taichi/runtime/llvm/runtime.cpp
+      const int32_t delta_bits =
+          (union_cast<int32_t>(f) & 0x80000000) | union_cast<int32_t>(0.5f);
+      const float delta = union_cast<float>(delta_bits);
+      return static_cast<C>(f + delta);
+    }
+
+    void mtl_set_partial_bits(SNodeBitPointer bp,
+                              uint32_t value,
+                              uint32_t bits) {
+      // See taichi/runtime/llvm/runtime.cpp
+      //
+      // We could have encoded |bits| as a compile time constant, but I guess
+      // the performance improvement is negligible.
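+      //
+      // Worked example (illustrative numbers): for a field at bit offset 3
+      // with bits = 5, the mask below evaluates to
+      //   ((~0u) << (32 - 5)) >> (32 - 3 - 5) = 0xF8000000u >> 24 = 0x000000F8,
+      // i.e. exactly bits [3, 8) of the 32-bit physical word are selected.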
+      using P = uint32_t;  // (P)hysical type
+      constexpr int N = sizeof(P) * 8;
+      // precondition: |mask| & |value| == |value|
+      const uint32_t mask =
+          ((~(uint32_t)0U) << (N - bits)) >> (N - bp.offset - bits);
+      device auto *atm_ptr = reinterpret_cast<device atomic_uint *>(bp.base);
+      bool ok = false;
+      while (!ok) {
+        P old_val = *(bp.base);
+        P new_val = (old_val & (~mask)) | (value << bp.offset);
+        ok = atomic_compare_exchange_weak_explicit(atm_ptr, &old_val, new_val,
+                                                   metal::memory_order_relaxed,
+                                                   metal::memory_order_relaxed);
+      }
+    }
+
+    void mtl_set_full_bits(SNodeBitPointer bp, uint32_t value) {
+      device auto *atm_ptr = reinterpret_cast<device atomic_uint *>(bp.base);
+      atomic_store_explicit(atm_ptr, value, metal::memory_order_relaxed);
+    }
+
+    uint32_t mtl_atomic_add_partial_bits(SNodeBitPointer bp,
+                                         uint32_t value,
+                                         uint32_t bits) {
+      // See taichi/runtime/llvm/runtime.cpp
+      using P = uint32_t;  // (P)hysical type
+      constexpr int N = sizeof(P) * 8;
+      // precondition: |mask| & |value| == |value|
+      const uint32_t mask =
+          ((~(uint32_t)0U) << (N - bits)) >> (N - bp.offset - bits);
+      device auto *atm_ptr = reinterpret_cast<device atomic_uint *>(bp.base);
+      P old_val = 0;
+      bool ok = false;
+      while (!ok) {
+        old_val = *(bp.base);
+        P new_val = old_val + (value << bp.offset);
+        // The above computation might overflow |bits|, so we have to OR them
+        // again, with the mask applied.
+        new_val = (old_val & (~mask)) | (new_val & mask);
+        ok = atomic_compare_exchange_weak_explicit(atm_ptr, &old_val, new_val,
+                                                   metal::memory_order_relaxed,
+                                                   metal::memory_order_relaxed);
+      }
+      return old_val;
+    }
+
+    uint32_t mtl_atomic_add_full_bits(SNodeBitPointer bp, uint32_t value) {
+      // When all the bits are used, we can replace CAS with a simple add.
+      device auto *atm_ptr = reinterpret_cast<device atomic_uint *>(bp.base);
+      return atomic_fetch_add_explicit(atm_ptr, value,
+                                       metal::memory_order_relaxed);
+    }
+
+    namespace detail {
+    // Metal supports C++ template specialization... what a crazy world
+    template <bool Signed>
+    struct SHRSelector {
+      using type = int32_t;
+    };
+
+    template <>
+    struct SHRSelector<false> {
+      using type = uint32_t;
+    };
+    }  // namespace detail
+
+    // (C)ompute type
+    template <typename C>
+    C mtl_get_partial_bits(SNodeBitPointer bp, uint32_t bits) {
+      using P = uint32_t;  // (P)hysical type
+      constexpr int N = sizeof(P) * 8;
+      const P phy_val = *(bp.base);
+      // Use CSel instead of C to preserve the bit width.
+      using CSel = typename detail::SHRSelector<is_signed<C>::value>::type;
+      // SHL is identical between signed and unsigned integrals.
+      const auto step1 = static_cast<CSel>(phy_val << (N - (bp.offset + bits)));
+      // ASHR vs LSHR is implicitly encoded in type CSel.
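+      // Worked example (illustrative numbers): reading a signed 5-bit field
+      // at offset 3 from phy_val = 0x000000F8 (field bits 0b11111):
+      //   step1 = CSel(phy_val << 24) = 0xF8000000
+      //   step1 >> 27 = -1   when CSel is int32_t (arithmetic shift),
+      //               = 0x1F when CSel is uint32_t (logical shift).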
+      return static_cast<C>(step1 >> (N - bits));
+    }
+
+    template <typename C>
+    C mtl_get_full_bits(SNodeBitPointer bp) {
+      return static_cast<C>(*(bp.base));
+    })
+METAL_END_SRC_DEF
+// clang-format on
+
+#undef METAL_BEGIN_SRC_DEF
+#undef METAL_END_SRC_DEF
+
+#include "taichi/backends/metal/shaders/epilog.h"
diff --git a/taichi/backends/metal/struct_metal.cpp b/taichi/backends/metal/struct_metal.cpp
index 8fc536cfa4295..f6cba791bed92 100644
--- a/taichi/backends/metal/struct_metal.cpp
+++ b/taichi/backends/metal/struct_metal.cpp
@@ -20,6 +20,7 @@ namespace shaders {
 #define TI_INSIDE_METAL_CODEGEN
 #include "taichi/backends/metal/shaders/runtime_structs.metal.h"
 #include "taichi/backends/metal/shaders/runtime_utils.metal.h"
+#include "taichi/backends/metal/shaders/snode_bit_pointer.metal.h"
 #undef TI_INSIDE_METAL_CODEGEN
 
 #include "taichi/backends/metal/shaders/runtime_structs.metal.h"
@@ -63,17 +64,25 @@ class StructCompiler {
   {
     max_snodes_ = 0;
     has_sparse_snode_ = false;
+#define CHECK_UNSUPPORTED_TYPE(type_case)                                 \
+  else if (ty == SNodeType::type_case) {                                  \
+    TI_ERROR("Metal backend does not support SNode=" #type_case " yet");  \
+  }
     for (const auto &sn : snodes_) {
       const auto ty = sn->type;
-      if (ty == SNodeType::root || ty == SNodeType::dense ||
-          ty == SNodeType::bitmasked || ty == SNodeType::dynamic ||
-          ty == SNodeType::pointer) {
+      if (ty == SNodeType::place) {
+        // do nothing
+      }
+      CHECK_UNSUPPORTED_TYPE(bit_array)
+      CHECK_UNSUPPORTED_TYPE(hash)
+      else {
         max_snodes_ = std::max(max_snodes_, sn->id);
       }
       has_sparse_snode_ = has_sparse_snode_ || is_supported_sparse_type(ty);
     }
     ++max_snodes_;
   }
+#undef CHECK_UNSUPPORTED_TYPE
 
   CompiledStructs result;
   result.root_size = compute_snode_size(&root);
@@ -123,7 +132,7 @@ class StructCompiler {
           "/*dynamic=*/{};",
           kAlignment);
     } else {
-      // `root`, `dense`, `pointer`
+      // `root`, `dense`, `bit_struct`
       emit("  constant static constexpr int stride = elem_stride * n;");
     }
     emit("");
@@ -149,9 +158,11 @@ class StructCompiler {
       emit("    nm.mem_alloc = ma;");
       emit("    const auto amb_idx = rtm->ambient_indices[{}];", snid);
       emit("    rep_.init(addr, nm, amb_idx);");
-    } else {
-      // `dense` or `root`
+    } else if (ty == SNodeType::root || ty == SNodeType::dense) {
+      // `root`, `dense`
       emit("    rep_.init(addr);");
+    } else {
+      TI_UNREACHABLE;
     }
     emit("  }}\n");
   }
@@ -185,8 +196,16 @@ class StructCompiler {
   }
 
   void generate_types(const SNode &snode) {
+    if (snode.is_bit_level) {
+      // Nothing to generate for bit-level SNodes -- they are part of their
+      // parent's intrinsic memory.
+      return;
+    }
+
+    const auto snty = snode.type;
     const bool is_place = snode.is_place();
-    if (!is_place) {
+    const bool should_gen_cell = !(is_place || (snty == SNodeType::bit_struct));
+    if (should_gen_cell) {
+      // "_ch" is a legacy word for child. The correct notion should be cell.
       // Generate {snode}_ch
       const std::string class_name = snode.node_type_name + "_ch";
       emit("class {} {{", class_name);
@@ -213,7 +232,6 @@ class StructCompiler {
     }
     emit("");
     const auto &node_name = snode.node_type_name;
-    const auto snty = snode.type;
     if (is_place) {
       const auto dt_name = metal_data_type_name(snode.dt);
       emit("struct {} {{", node_name);
@@ -225,7 +243,21 @@ class StructCompiler {
           node_name);
      emit("      : val((device {}*)v) {{}}", dt_name);
      emit("");
-      emit("  device {}* val;", dt_name);
+      emit("  device {} *val;", dt_name);
+      emit("}};");
+    } else if (snty == SNodeType::bit_struct) {
+      // TODO: bit_struct and place share a lot in common.
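+      // With a 32-bit physical type, the emitted Metal struct looks roughly
+      // like this (sketch; `S2` is a hypothetical node name):
+      //
+      //   struct S2 {
+      //     // bit_struct
+      //     constant static constexpr int stride = sizeof(uint32_t);
+      //     S2(device byte *b, device Runtime *, device MemoryAllocator *)
+      //         : base(b) {}
+      //     device byte *base;
+      //   };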
+ const auto dt_name = metal_data_type_name(DataType(snode.physical_type)); + emit("struct {} {{", node_name); + emit(" // bit_struct"); + emit(" constant static constexpr int stride = sizeof({});", dt_name); + emit(""); + // `bit_struct` constructor + emit(" {}(device byte *b, device Runtime *, device MemoryAllocator *)", + node_name); + emit(" : base(b) {{}}"); + emit(""); + emit(" device byte *base;"); emit("}};"); } else if (snty == SNodeType::dense || snty == SNodeType::root || snty == SNodeType::bitmasked || snty == SNodeType::dynamic || @@ -252,10 +284,8 @@ class StructCompiler { emit(" SNodeRep_{} rep_;", snty_name); emit("}};"); } else { - TI_ERROR( - "SNodeType={} not supported on Metal.\nConsider using " - "ti.init(ti.cpu) if you want to use sparse data structures.", - snode_type_name(snode.type)); + // We have checked the type support previously. + TI_UNREACHABLE; } emit(""); } @@ -264,18 +294,34 @@ class StructCompiler { if (sn->is_place()) { return metal_data_type_bytes(to_metal_type(sn->dt)); } - + if (sn->is_bit_level) { + // A bit-level SNode occupies a fration of a byte. Just return 0 here and + // special handling the bit_* SNode containers. + return 0; + } const int n = get_n(*sn); size_t ch_size = 0; - for (const auto &ch : sn->ch) { - const size_t ch_offset = ch_size; - const auto *ch_sn = ch.get(); - ch_size += compute_snode_size(ch_sn); - if (!ch_sn->is_place()) { - snode_descriptors_.find(ch_sn->id)->second.mem_offset_in_parent = - ch_offset; + if (sn->type == SNodeType::bit_struct) { + // The host side should have inferred all the necessary info of |sn|. + TI_ASSERT(sn->physical_type != nullptr); + ch_size = data_type_size(sn->physical_type); + // |ch_size| should at least be 4 bytes on GPU. In addition, Metal: + // 1. does not support 8-byte data types in the device address space. + // 2. only supports 4-byte atomic integral types (or atomic_bool). 
+ TI_ERROR_IF(ch_size != 4, + "bit_struct physical type must be exactly 32 bits on Metal"); + } else { + for (const auto &ch : sn->ch) { + const size_t ch_offset = ch_size; + const auto *ch_sn = ch.get(); + ch_size += compute_snode_size(ch_sn); + if (!ch_sn->is_place()) { + snode_descriptors_.find(ch_sn->id)->second.mem_offset_in_parent = + ch_offset; + } } } + SNodeDescriptor sn_desc; sn_desc.snode = sn; sn_desc.element_stride = ch_size; @@ -310,8 +356,11 @@ class StructCompiler { emit(" NodeManagerData::ElemIndex ambient_indices[{}];", max_snodes_); emit(" uint32_t rand_seeds[{}];", kNumRandSeeds); emit("}};"); + emit(""); line_appender_.append_raw(shaders::kMetalRuntimeUtilsSourceCode); emit(""); + line_appender_.append_raw(shaders::kMetalSNodeBitPointerSourceCode); + emit(""); } size_t compute_runtime_size() { diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 105fcf5808319..f31d9798acbc3 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1067,7 +1067,7 @@ llvm::Value *CodeGenLLVM::float_to_custom_int(CustomFloatType *cft, llvm::Value *real) { llvm::Value *s = nullptr; - // Compute int(input * (1.0 / scale) + 0.5) + // Compute int(real * (1.0 / scale) + 0.5) auto s_numeric = 1.0 / cft->get_scale(); auto compute_type = cft->get_compute_type(); s = builder->CreateFPCast( diff --git a/taichi/inc/extensions.inc.h b/taichi/inc/extensions.inc.h index c9ee2c7bd20c9..2099c56ec5735 100644 --- a/taichi/inc/extensions.inc.h +++ b/taichi/inc/extensions.inc.h @@ -1,9 +1,10 @@ // Lists of extension features -PER_EXTENSION(sparse) // Sparse data structures -PER_EXTENSION(async_mode) // Asynchronous execution mode -PER_EXTENSION(quant) // Quantization -PER_EXTENSION(data64) // Metal doesn't support 64-bit data buffers yet... -PER_EXTENSION(adstack) // For keeping the history of mutable local variables -PER_EXTENSION(bls) // Block-local storage -PER_EXTENSION(assertion) // Run-time asserts in Taichi kernels -PER_EXTENSION(extfunc) // Invoke external functions or backend source +PER_EXTENSION(sparse) // Sparse data structures +PER_EXTENSION(async_mode) // Asynchronous execution mode +PER_EXTENSION(quant) // Quantization +PER_EXTENSION(quant_basic) // Basic operations in quantization +PER_EXTENSION(data64) // Metal doesn't support 64-bit data buffers yet... 
+PER_EXTENSION(adstack) // For keeping the history of mutable local variables +PER_EXTENSION(bls) // Block-local storage +PER_EXTENSION(assertion) // Run-time asserts in Taichi kernels +PER_EXTENSION(extfunc) // Invoke external functions or backend source diff --git a/taichi/program/extension.cpp b/taichi/program/extension.cpp index d5eef248ccb60..fb3a2d416b471 100644 --- a/taichi/program/extension.cpp +++ b/taichi/program/extension.cpp @@ -10,17 +10,19 @@ bool is_extension_supported(Arch arch, Extension ext) { static std::unordered_map> arch2ext = { {Arch::x64, {Extension::sparse, Extension::async_mode, Extension::quant, - Extension::data64, Extension::adstack, Extension::assertion, - Extension::extfunc}}, + Extension::quant_basic, Extension::data64, Extension::adstack, + Extension::assertion, Extension::extfunc}}, {Arch::arm64, {Extension::sparse, Extension::async_mode, Extension::quant, - Extension::data64, Extension::adstack, Extension::assertion}}, + Extension::quant_basic, Extension::data64, Extension::adstack, + Extension::assertion}}, {Arch::cuda, {Extension::sparse, Extension::async_mode, Extension::quant, - Extension::data64, Extension::adstack, Extension::bls, - Extension::assertion}}, + Extension::quant_basic, Extension::data64, Extension::adstack, + Extension::bls, Extension::assertion}}, {Arch::metal, - {Extension::adstack, Extension::assertion, Extension::async_mode}}, + {Extension::adstack, Extension::assertion, Extension::quant_basic, + Extension::async_mode}}, {Arch::opengl, {Extension::extfunc}}, {Arch::cc, {Extension::data64, Extension::extfunc, Extension::adstack}}, }; diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 86ec513a9de29..d0f6015ec09ac 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -1,15 +1,15 @@ #include "kernel.h" -#include "taichi/util/statistics.h" -#include "taichi/common/task.h" -#include "taichi/program/program.h" -#include "taichi/program/async_engine.h" -#include "taichi/codegen/codegen.h" #include "taichi/backends/cuda/cuda_driver.h" +#include "taichi/codegen/codegen.h" +#include "taichi/common/task.h" #include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" -#include "taichi/util/action_recorder.h" +#include "taichi/program/async_engine.h" #include "taichi/program/extension.h" +#include "taichi/program/program.h" +#include "taichi/util/action_recorder.h" +#include "taichi/util/statistics.h" TLANG_NAMESPACE_BEGIN @@ -239,9 +239,11 @@ void Kernel::LaunchContextBuilder::set_arg_raw(int i, uint64 d) { !kernel_->args[i].is_nparray, "Assigning scalar value to numpy array argument is not allowed"); - ActionRecorder::get_instance().record( - "set_arg_raw", {ActionArg("kernel_name", kernel_->name), - ActionArg("arg_id", i), ActionArg("val", (int64)d)}); + if (!kernel_->is_evaluator) { + ActionRecorder::get_instance().record( + "set_arg_raw", {ActionArg("kernel_name", kernel_->name), + ActionArg("arg_id", i), ActionArg("val", (int64)d)}); + } ctx_->set_arg(i, d); } diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index aac6e6ab4788a..9f225427c1749 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -3,7 +3,7 @@ from pytest import approx -@ti.test(require=ti.extension.quant, debug=True) +@ti.test(require=ti.extension.quant_basic, debug=True) def test_simple_array(): ci13 = ti.type_factory.custom_int(13, True) cu19 = ti.type_factory.custom_int(19, False) @@ -37,7 +37,7 @@ def verify_val(): verify_val.__wrapped__() 
-@ti.test(require=ti.extension.quant, debug=True) +@ti.test(require=ti.extension.quant_basic, debug=True) def test_custom_int_load_and_store(): ci13 = ti.type_factory.custom_int(13, True) cu14 = ti.type_factory.custom_int(14, False) @@ -78,7 +78,7 @@ def verify_val(idx: ti.i32): verify_val.__wrapped__(idx) -@ti.test(require=ti.extension.quant) +@ti.test(require=ti.extension.quant_basic) def test_custom_int_full_struct(): cit = ti.type_factory.custom_int(32, True) x = ti.field(dtype=cit) @@ -137,7 +137,7 @@ def verify_val(test_val: ti.ext_arr()): test_single_bit_struct(32, 32, [10, 10, 12], np.array([11, 19, 2020])) -@ti.test(require=ti.extension.quant, debug=True) +@ti.test(require=[ti.extension.quant_basic, ti.extension.sparse], debug=True) def test_bit_struct_struct_for(): block_size = 16 N = 64 diff --git a/tests/python/test_custom_float.py b/tests/python/test_custom_float.py index c68f4676081fc..e3fded39d4e45 100644 --- a/tests/python/test_custom_float.py +++ b/tests/python/test_custom_float.py @@ -3,7 +3,7 @@ from pytest import approx -@ti.test(require=ti.extension.quant) +@ti.test(require=ti.extension.quant_basic) def test_custom_float(): ci13 = ti.type_factory.custom_int(bits=13) cft = ti.type_factory.custom_float(significand_type=ci13, scale=0.1) @@ -25,7 +25,7 @@ def foo(): assert x[None] == approx(0.7) -@ti.test(require=ti.extension.quant) +@ti.test(require=ti.extension.quant_basic) def test_custom_matrix_rotation(): ci16 = ti.type_factory.custom_int(bits=16) cft = ti.type_factory.custom_float(significand_type=ci16, @@ -53,7 +53,7 @@ def rotate_18_degrees(): assert x[None][1, 1] == approx(0, abs=1e-4) -@ti.test(require=ti.extension.quant) +@ti.test(require=ti.extension.quant_basic) def test_custom_float_implicit_cast(): ci13 = ti.type_factory.custom_int(bits=13) cft = ti.type_factory.custom_float(significand_type=ci13, scale=0.1) @@ -69,7 +69,7 @@ def foo(): assert x[None] == approx(10.0) -@ti.test(require=ti.extension.quant) +@ti.test(require=ti.extension.quant_basic) def test_cache_read_only(): ci15 = ti.type_factory.custom_int(bits=15) cft = ti.type_factory.custom_float(significand_type=ci15, scale=0.1) diff --git a/tests/python/test_custom_int.py b/tests/python/test_custom_int.py index 77f31497ca478..f2cf02451958c 100644 --- a/tests/python/test_custom_int.py +++ b/tests/python/test_custom_int.py @@ -1,7 +1,7 @@ import taichi as ti -@ti.test(require=ti.extension.quant) +@ti.test(require=ti.extension.quant_basic) def test_custom_int_implicit_cast(): ci13 = ti.type_factory.custom_int(13, True) x = ti.field(dtype=ci13) diff --git a/tests/python/test_custom_type_atomics.py b/tests/python/test_custom_type_atomics.py index b25a441124c30..5e684638ada4c 100644 --- a/tests/python/test_custom_type_atomics.py +++ b/tests/python/test_custom_type_atomics.py @@ -2,7 +2,7 @@ from pytest import approx -@ti.test(require=ti.extension.quant, debug=True) +@ti.test(require=ti.extension.quant_basic, debug=True) def test_custom_int_atomics(): ci13 = ti.type_factory.custom_int(13, True) ci5 = ti.type_factory.custom_int(5, True) @@ -36,7 +36,7 @@ def foo(): assert z[None] == 3 -@ti.test(require=ti.extension.quant, debug=True) +@ti.test(require=[ti.extension.quant_basic, ti.extension.data64], debug=True) def test_custom_int_atomics_b64(): ci13 = ti.type_factory.custom_int(13, True) @@ -60,7 +60,7 @@ def foo(): assert x[2] == 315 -@ti.test(require=ti.extension.quant, debug=True) +@ti.test(require=ti.extension.quant_basic, debug=True) def test_custom_float_atomics(): ci13 = ti.type_factory.custom_int(13, 
True) ci19 = ti.type_factory.custom_int(19, False) diff --git a/tests/python/test_matrix_different_type.py b/tests/python/test_matrix_different_type.py index 3e4095f5c961c..6efdb41a10927 100644 --- a/tests/python/test_matrix_different_type.py +++ b/tests/python/test_matrix_different_type.py @@ -30,7 +30,7 @@ def verify(): # TODO: Support different element types of Matrix on opengl -@ti.test(exclude=ti.opengl) +@ti.test(require=ti.extension.data64, exclude=ti.opengl) def test_matrix(): type_list = [[ti.f32, ti.i32], [ti.i64, ti.f32]] a = ti.Matrix.field(len(type_list), @@ -66,7 +66,7 @@ def verify(): verify() -@ti.test(require=ti.extension.quant) +@ti.test(require=ti.extension.quant_basic) def test_custom_type(): cit1 = ti.type_factory.custom_int(bits=10, signed=True) cft1 = ti.type_factory.custom_float(cit1, scale=0.1)