From 356c4976badece392a87995e594ca005450b8f75 Mon Sep 17 00:00:00 2001 From: Cheng Cao Date: Mon, 20 Jun 2022 13:27:50 -0700 Subject: [PATCH 1/9] Bunch of things --- taichi/backends/vulkan/vulkan_device.cpp | 15 +- taichi/codegen/spirv/kernel_utils.h | 6 +- taichi/codegen/spirv/lib_tiny_ir.h | 262 ++++++++++ .../codegen/spirv/snode_struct_compiler.cpp | 44 ++ taichi/codegen/spirv/snode_struct_compiler.h | 6 + taichi/codegen/spirv/spirv_codegen.cpp | 32 ++ taichi/codegen/spirv/spirv_ir_builder.h | 16 +- taichi/codegen/spirv/spirv_types.cpp | 448 ++++++++++++++++++ taichi/codegen/spirv/spirv_types.h | 240 ++++++++++ taichi/runtime/gfx/runtime.cpp | 164 ++++--- taichi/runtime/gfx/runtime.h | 11 +- 11 files changed, 1151 insertions(+), 93 deletions(-) create mode 100644 taichi/codegen/spirv/lib_tiny_ir.h create mode 100644 taichi/codegen/spirv/spirv_types.cpp create mode 100644 taichi/codegen/spirv/spirv_types.h diff --git a/taichi/backends/vulkan/vulkan_device.cpp b/taichi/backends/vulkan/vulkan_device.cpp index e71d97586da1b..7bce62cc32c3f 100644 --- a/taichi/backends/vulkan/vulkan_device.cpp +++ b/taichi/backends/vulkan/vulkan_device.cpp @@ -1310,7 +1310,6 @@ DeviceAllocation VulkanDevice::allocate_memory(const AllocParams ¶ms) { if (params.usage & AllocUsage::Index) { buffer_info.usage |= VK_BUFFER_USAGE_INDEX_BUFFER_BIT; } - buffer_info.sharingMode = VK_SHARING_MODE_CONCURRENT; uint32_t queue_family_indices[] = {compute_queue_family_index_, graphics_queue_family_index_}; @@ -1351,20 +1350,22 @@ DeviceAllocation VulkanDevice::allocate_memory(const AllocParams ¶ms) { if (params.host_read && params.host_write) { #endif //__APPLE__ // This should be the unified memory on integrated GPUs - alloc_info.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; - alloc_info.preferredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT | - VK_MEMORY_PROPERTY_HOST_CACHED_BIT; + alloc_info.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | + VK_MEMORY_PROPERTY_HOST_CACHED_BIT; + alloc_info.preferredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; #ifdef __APPLE__ // weird behavior on apple: if coherent bit is not set, then the memory // writes between map() and unmap() cannot be seen by gpu alloc_info.preferredFlags |= VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; #endif //__APPLE__ } else if (params.host_read) { - alloc_info.usage = VMA_MEMORY_USAGE_GPU_TO_CPU; + alloc_info.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; + alloc_info.preferredFlags = VK_MEMORY_PROPERTY_HOST_CACHED_BIT; } else if (params.host_write) { - alloc_info.usage = VMA_MEMORY_USAGE_CPU_TO_GPU; + alloc_info.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; + alloc_info.preferredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; } else { - alloc_info.usage = VMA_MEMORY_USAGE_GPU_ONLY; + alloc_info.requiredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; } if (get_cap(DeviceCapability::spirv_has_physical_storage_buffer)) { diff --git a/taichi/codegen/spirv/kernel_utils.h b/taichi/codegen/spirv/kernel_utils.h index e4b87a8e7df1b..75c4ada26a4f0 100644 --- a/taichi/codegen/spirv/kernel_utils.h +++ b/taichi/codegen/spirv/kernel_utils.h @@ -6,6 +6,7 @@ #include "taichi/ir/offloaded_task_type.h" #include "taichi/ir/type.h" +#include "taichi/ir/transforms.h" #include "taichi/backends/device.h" namespace taichi { @@ -234,11 +235,14 @@ class KernelContextAttributes { return args_bytes(); } + std::unordered_map arr_access; + TI_IO_DEF(arg_attribs_vec_, ret_attribs_vec_, args_bytes_, rets_bytes_, - extra_args_bytes_); + extra_args_bytes_, + arr_access); private: std::vector arg_attribs_vec_; diff --git a/taichi/codegen/spirv/lib_tiny_ir.h b/taichi/codegen/spirv/lib_tiny_ir.h new file mode 100644 index 0000000000000..4d5e39d029542 --- /dev/null +++ b/taichi/codegen/spirv/lib_tiny_ir.h @@ -0,0 +1,262 @@ +#pragma once + +#include "taichi/common/core.h" + +#include +#include + +namespace taichi { +namespace tinyir { + +template +T ceil_div(T v, T div) { + return (v / div) + (v % div ? 1 : 0); +} + +// Forward decl +class Polymorphic; +class Node; +class Type; +class LayoutContext; +class MemRefElementTypeInterface; +class MemRefAggregateTypeInterface; +class ShapedTypeInterface; +class AggregateTypeInterface; +class PointerTypeInterface; +class Block; +class Visitor; + +class Polymorphic { + public: + virtual ~Polymorphic() { + } + + template + bool is() const { + return dynamic_cast(this) != nullptr; + } + + template + T *as() { + return static_cast(this); + } + + template + const T *as() const { + return static_cast(this); + } + + template + T *cast() { + return dynamic_cast(this); + } + + template + const T *cast() const { + return dynamic_cast(this); + } + + bool operator==(const Polymorphic &other) const { + return typeid(*this) == typeid(other) && is_equal(other); + } + + const bool equals(const Polymorphic *other) const { + return (*this) == (*other); + } + + private: + virtual bool is_equal(const Polymorphic &other) const = 0; +}; + +class Node : public Polymorphic { + public: + using NodeRefs = const std::vector; + + Node() { + } + + virtual ~Node() { + } + + const std::string &debug_name() const { + return debug_name_; + } + + void set_debug_name(const std::string &s) { + debug_name_ = s; + } + + virtual NodeRefs incoming() const { + return {}; + } + + virtual NodeRefs outgoing() const { + return {}; + } + + virtual bool is_leaf() const { + return false; + } + + virtual bool is_tree_node() const { + return false; + } + + private: + virtual bool is_equal(const Polymorphic &other) const { + return false; + } + + std::string debug_name_; +}; + +class Type : public Node { + public: + Type() { + } + + private: + virtual bool is_equal(const Polymorphic &other) const { + return false; + } +}; + +// The default LayoutContext is the standard C layout +class LayoutContext : public Polymorphic { + private: + std::unordered_map size_cache_; + std::unordered_map + alignment_cache_; + std::unordered_map> + elem_offset_cache_; + + public: + void register_size(const MemRefElementTypeInterface *t, size_t size) { + TI_ASSERT(size != 0); + size_cache_[t] = size; + } + + void register_alignment(const MemRefElementTypeInterface *t, size_t size) { + TI_ASSERT(size != 0); + alignment_cache_[t] = size; + } + + void register_aggregate(const MemRefAggregateTypeInterface *t, int num_elem) { + elem_offset_cache_[t] = {}; + elem_offset_cache_[t].resize(num_elem, 0); + } + + void register_elem_offset(const MemRefAggregateTypeInterface *t, + int n, size_t offset) { + TI_ASSERT(elem_offset_cache_.find(t) != elem_offset_cache_.end()); + elem_offset_cache_[t][n] = offset; + } + + // Size or alignment can not be zero + size_t query_size(const MemRefElementTypeInterface *t) { + if (size_cache_.find(t) != size_cache_.end()) { + return size_cache_[t]; + } else { + return 0; + } + } + + size_t query_alignment(const MemRefElementTypeInterface *t) { + if (alignment_cache_.find(t) != alignment_cache_.end()) { + return alignment_cache_[t]; + } else { + return 0; + } + } + + size_t query_elem_offset(const MemRefAggregateTypeInterface *t, int n) { + if (elem_offset_cache_.find(t) != elem_offset_cache_.end()) { + return elem_offset_cache_[t][n]; + } else { + return 0; + } + } + + private: + virtual bool is_equal(const Polymorphic &other) const { + // This is only called when `other` has the same typeid + return true; + } +}; + +class MemRefElementTypeInterface { + public: + virtual size_t memory_size(LayoutContext& ctx) const = 0; + virtual size_t memory_alignment_size(LayoutContext &ctx) const = 0; +}; + +class MemRefAggregateTypeInterface : public MemRefElementTypeInterface { + public: + virtual size_t nth_element_offset(int n, LayoutContext &ctx) const = 0; +}; + +class AggregateTypeInterface { + public: + virtual const Type *nth_element_type(int n) const = 0; + virtual int get_num_elements() const = 0; +}; + +class ShapedTypeInterface { + public: + virtual const Type *element_type() const = 0; + virtual bool is_constant_shape() const = 0; + virtual std::vector get_constant_shape() const = 0; +}; + +class PointerTypeInterface { + public: + virtual const Type *get_pointed_type() const = 0; +}; + +class Block { + public: + template + T *emplace_back(E... args) { + nodes_.push_back(std::make_unique(args...)); + return static_cast(nodes_.back().get()); + } + + template + T *push_back(std::unique_ptr &&val) { + T *ptr = val.get(); + nodes_.push_back(std::move(val)); + return ptr; + } + + const std::vector> &nodes() const { + return nodes_; + } + + private: + std::vector> nodes_; +}; + +class Visitor { + public: + virtual ~Visitor() { + } + + virtual void visit(const Node *node) { + if (node->is()) { + visit_type(node->as()); + } + } + + virtual void visit_type(const Type *type) { + } + + virtual void visit(const Block *block) { + for (auto &n : block->nodes()) { + visit(n.get()); + } + } +}; + +} +} + diff --git a/taichi/codegen/spirv/snode_struct_compiler.cpp b/taichi/codegen/spirv/snode_struct_compiler.cpp index 6be96bb8c8663..330b5fcc40c50 100644 --- a/taichi/codegen/spirv/snode_struct_compiler.cpp +++ b/taichi/codegen/spirv/snode_struct_compiler.cpp @@ -14,11 +14,55 @@ class StructCompiler { result.root = &root; result.root_size = compute_snode_size(&root); result.snode_descriptors = std::move(snode_descriptors_); + /* + result.type_factory = new tinyir::Block; + result.root_type = construct(*result.type_factory, &root); + */ TI_TRACE("RootBuffer size={}", result.root_size); + + /* + std::unique_ptr b = ir_reduce_types(result.type_factory); + + TI_WARN("Original types:\n{}", ir_print_types(result.type_factory)); + + TI_WARN("Reduced types:\n{}", ir_print_types(b.get())); + */ + return result; } private: + + const tinyir::Type *construct(tinyir::Block &ir_module, SNode *sn) { + const tinyir::Type *cell_type = nullptr; + + if (sn->is_place()) { + // Each cell is a single Type + cell_type = translate_ti_primitive(ir_module, sn->dt); + } else { + // Each cell is a struct + std::vector struct_elements; + for (auto &ch : sn->ch) { + const tinyir::Type *elem_type = construct(ir_module, ch.get()); + struct_elements.push_back(elem_type); + } + tinyir::Type *st = ir_module.emplace_back(struct_elements); + st->set_debug_name(fmt::format("{}_{}", snode_type_name(sn->type), sn->get_name())); + cell_type = st; + + if (sn->type == SNodeType::pointer) { + cell_type = ir_module.emplace_back(cell_type); + } + } + + if (sn->num_cells_per_container == 1 || sn->is_scalar()) { + return cell_type; + } else { + return ir_module.emplace_back(cell_type, + sn->num_cells_per_container); + } + } + std::size_t compute_snode_size(SNode *sn) { const bool is_place = sn->is_place(); diff --git a/taichi/codegen/spirv/snode_struct_compiler.h b/taichi/codegen/spirv/snode_struct_compiler.h index f2b4cba6cbe5d..2109ff5af7052 100644 --- a/taichi/codegen/spirv/snode_struct_compiler.h +++ b/taichi/codegen/spirv/snode_struct_compiler.h @@ -5,6 +5,8 @@ #include "taichi/ir/snode.h" +#include "spirv_types.h" + namespace taichi { namespace lang { namespace spirv { @@ -49,6 +51,10 @@ struct CompiledSNodeStructs { const SNode *root{nullptr}; // Map from SNode ID to its descriptor. SNodeDescriptorsMap snode_descriptors; + + // TODO: Use the new type compiler + // tinyir::Block *type_factory; + // const tinyir::Type *root_type; }; CompiledSNodeStructs compile_snode_structs(SNode &root); diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 6cd92965fd5d4..7014719077b4b 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -99,6 +99,7 @@ class TaskCodegen : public IRVisitor { struct Result { std::vector spirv_code; TaskAttributes task_attribs; + std::unordered_map arr_access; }; Result run() { @@ -129,6 +130,7 @@ class TaskCodegen : public IRVisitor { Result res; res.spirv_code = ir_->finalize(); res.task_attribs = std::move(task_attribs_); + res.arr_access = irpass::detect_external_ptr_access_in_task(task_ir_); return res; } @@ -1968,6 +1970,7 @@ class TaskCodegen : public IRVisitor { if (!ctx_attribs_->has_args()) return; + /* std::vector> struct_components_; for (auto &arg : ctx_attribs_->args()) { @@ -1991,6 +1994,31 @@ class TaskCodegen : public IRVisitor { ctx_attribs_->extra_args_mem_offset() + i * 4); } args_struct_type_ = ir_->create_struct_type(struct_components_); + */ + + tinyir::Block blk; + std::vector element_types; + for (auto &arg : ctx_attribs_->args()) { + const tinyir::Type *t; + if (arg.is_array && + device_->get_cap( + DeviceCapability::spirv_has_physical_storage_buffer)) { + t = blk.emplace_back(/*num_bits=*/64, /*is_signed=*/false); + } else { + t = translate_ti_primitive(blk, PrimitiveType::get(arg.dtype)); + } + element_types.push_back(t); + } + const tinyir::Type *i32_type = + blk.emplace_back(/*num_bits=*/32, /*is_signed=*/true); + for (int i = 0; i < ctx_attribs_->extra_args_bytes() / 4; i++) { + element_types.push_back(i32_type); + } + const tinyir::Type *struct_type = blk.emplace_back(element_types); + + STD140LayoutContext layout_ctx; + auto map = ir_translate_to_spirv(&blk, layout_ctx, ir_.get()); + args_struct_type_.id = map[struct_type]; args_buffer_value_ = ir_->uniform_struct_argument(args_struct_type_, 0, 0, "args"); @@ -2199,6 +2227,10 @@ void KernelCodegen::run(TaichiKernelAttributes &kernel_attribs, TaskCodegen cgen(tp); auto task_res = cgen.run(); + for (auto &[id, access] : task_res.arr_access) { + ctx_attribs_.arr_access[id] = ctx_attribs_.arr_access[id] | access; + } + std::vector optimized_spv(task_res.spirv_code); size_t last_size; diff --git a/taichi/codegen/spirv/spirv_ir_builder.h b/taichi/codegen/spirv/spirv_ir_builder.h index 04d4ce00dc95e..943657b0ed5a7 100644 --- a/taichi/codegen/spirv/spirv_ir_builder.h +++ b/taichi/codegen/spirv/spirv_ir_builder.h @@ -454,6 +454,14 @@ class IRBuilder { Value query_value(std::string name) const; // Check whether a value has been evaluated bool check_value_existence(const std::string &name) const; + // Create a new SSA value + Value new_value(const SType &type, ValueKind flag) { + Value val; + val.id = id_counter_++; + val.stype = type; + val.flag = flag; + return val; + } // Support easy access to trivial data types SType i64_type() const { @@ -508,14 +516,6 @@ class IRBuilder { Value rand_i32(Value global_tmp_); private: - Value new_value(const SType &type, ValueKind flag) { - Value val; - val.id = id_counter_++; - val.stype = type; - val.flag = flag; - return val; - } - Value get_const(const SType &dtype, const uint64_t *pvalue, bool cache); SType declare_primitive_type(DataType dt); diff --git a/taichi/codegen/spirv/spirv_types.cpp b/taichi/codegen/spirv/spirv_types.cpp new file mode 100644 index 0000000000000..051315e417e2d --- /dev/null +++ b/taichi/codegen/spirv/spirv_types.cpp @@ -0,0 +1,448 @@ +#include "spirv_types.h" +#include "spirv_ir_builder.h" + +namespace taichi { +namespace lang { +namespace spirv { + +size_t StructType::memory_size(tinyir::LayoutContext &ctx) const { + if (size_t s = ctx.query_size(this)) { + return s; + } + + ctx.register_aggregate(this, elements_.size()); + + size_t size_head = 0; + int n = 0; + for (const Type *elem : elements_) { + TI_ASSERT(elem->is()); + const MemRefElementTypeInterface *mem_ref_type = + elem->cast(); + size_t elem_size = mem_ref_type->memory_size(ctx); + size_t elem_align = mem_ref_type->memory_alignment_size(ctx); + // First align the head ptr, then add the size + size_head = tinyir::ceil_div(size_head, elem_align) * elem_align; + ctx.register_elem_offset(this, n, size_head); + size_head += elem_size; + n++; + } + + if (ctx.is()) { + // With STD140 layout, the next member is rounded up to the alignment size. Thus we should simply size up the struct to the alignment. + size_t self_alignment = this->memory_alignment_size(ctx); + size_head = tinyir::ceil_div(size_head, self_alignment) * self_alignment; + } + + ctx.register_size(this, size_head); + return size_head; +} + +size_t StructType::memory_alignment_size( + tinyir::LayoutContext &ctx) const { + if (size_t s = ctx.query_alignment(this)) { + return s; + } + + size_t max_align = 0; + for (const Type *elem : elements_) { + TI_ASSERT(elem->is()); + max_align = std::max(max_align, elem->cast()->memory_alignment_size(ctx)); + } + + if (ctx.is()) { + // With STD140 layout, struct alignment is rounded up to `sizeof(vec4)` + constexpr size_t vec4_size = sizeof(float) * 4; + max_align = tinyir::ceil_div(max_align, vec4_size) * vec4_size; + } + + ctx.register_alignment(this, max_align); + return max_align; +} + +size_t StructType::nth_element_offset(int n, tinyir::LayoutContext &ctx) const { + this->memory_size(ctx); + + return ctx.query_elem_offset(this, n); +} + +SmallVectorType::SmallVectorType(const Type *element_type, int num_elements) + : element_type_(element_type), num_elements_(num_elements) { + TI_ASSERT(num_elements > 1 && num_elements_ <= 4); +} + + +size_t SmallVectorType::memory_size(tinyir::LayoutContext &ctx) const { + if (size_t s = ctx.query_size(this)) { + return s; + } + + size_t size = element_type_->cast() + ->memory_size(ctx) * num_elements_; + + ctx.register_size(this, size); + return size; +} + +size_t SmallVectorType::memory_alignment_size(tinyir::LayoutContext &ctx) const { + if (size_t s = ctx.query_alignment(this)) { + return s; + } + + size_t align = + element_type_->cast()->memory_size( + ctx); + + if (ctx.is() || ctx.is()) { + // For STD140 / STD430, small vectors are Power-of-Two aligned + // In C or "Scalar block layout", blocks are aligned to its compoment alignment + if (num_elements_ == 2) { + align *= 2; + } else { + align *= 4; + } + } + + ctx.register_alignment(this, align); + return align; +} + +size_t ArrayType::memory_size(tinyir::LayoutContext &ctx) const { + if (size_t s = ctx.query_size(this)) { + return s; + } + + size_t elem_align = + element_type_->cast()->memory_alignment_size( + ctx); + + if (ctx.is()) { + // For STD140, arrays element stride equals the base alignment of the array itself + elem_align = this->memory_alignment_size(ctx); + } + size_t size = elem_align * size_; + + ctx.register_size(this, size); + return size; +} + +size_t ArrayType::memory_alignment_size(tinyir::LayoutContext &ctx) const { + if (size_t s = ctx.query_alignment(this)) { + return s; + } + + size_t elem_align = element_type_->cast() + ->memory_alignment_size(ctx); + + if (ctx.is()) { + // With STD140 layout, array alignment is rounded up to `sizeof(vec4)` + constexpr size_t vec4_size = sizeof(float) * 4; + elem_align = tinyir::ceil_div(elem_align, vec4_size) * vec4_size; + } + + ctx.register_alignment(this, elem_align); + return elem_align; +} + +size_t ArrayType::nth_element_offset(int n, tinyir::LayoutContext &ctx) const { + size_t elem_align = this->memory_alignment_size(ctx); + + return elem_align * n; +} + +bool bitcast_possible(tinyir::Type *a, + tinyir::Type *b, + bool _inverted) { + if (a->is() && b->is()) { + return a->as()->num_bits() == b->as()->num_bits(); + } else if (a->is() && b->is()) { + return a->as()->num_bits() == b->as()->num_bits(); + } else if (!_inverted) { + return bitcast_possible(b, a, true); + } + return false; +} + +const tinyir::Type *translate_ti_primitive(tinyir::Block &ir_module, + const DataType t) { + if (t->is()) { + if (t == PrimitiveType::i8) { + return ir_module.emplace_back(/*num_bits=*/8, + /*is_signed=*/true); + } else if (t == PrimitiveType::i16) { + return ir_module.emplace_back(/*num_bits=*/16, + /*is_signed=*/true); + } else if (t == PrimitiveType::i32) { + return ir_module.emplace_back(/*num_bits=*/32, + /*is_signed=*/true); + } else if (t == PrimitiveType::i64) { + return ir_module.emplace_back(/*num_bits=*/64, + /*is_signed=*/true); + } else if (t == PrimitiveType::u8) { + return ir_module.emplace_back(/*num_bits=*/8, + /*is_signed=*/false); + } else if (t == PrimitiveType::u16) { + return ir_module.emplace_back(/*num_bits=*/16, + /*is_signed=*/false); + } else if (t == PrimitiveType::u32) { + return ir_module.emplace_back(/*num_bits=*/32, + /*is_signed=*/false); + } else if (t == PrimitiveType::u64) { + return ir_module.emplace_back(/*num_bits=*/64, + /*is_signed=*/false); + } else if (t == PrimitiveType::f16) { + return ir_module.emplace_back(/*num_bits=*/16); + } else if (t == PrimitiveType::f32) { + return ir_module.emplace_back(/*num_bits=*/32); + } else if (t == PrimitiveType::f64) { + return ir_module.emplace_back(/*num_bits=*/64); + } else { + TI_NOT_IMPLEMENTED; + } + } else { + TI_NOT_IMPLEMENTED; + } +} + +void TypeVisitor::visit_type(const tinyir::Type *type) { + if (type->is()) { + visit_physical_pointer_type(type->as()); + } else if (type->is()) { + visit_small_vector_type(type->as()); + } else if (type->is()) { + visit_array_type(type->as()); + } else if (type->is()) { + visit_struct_type(type->as()); + } else if (type->is()) { + visit_int_type(type->as()); + } else if (type->is()) { + visit_float_type(type->as()); + } +} + +class TypePrinter : public TypeVisitor { + private: + std::string result_; + STD140LayoutContext layout_ctx_; + + uint32_t head_{0}; + std::unordered_map idmap_; + + uint32_t get_id(const tinyir::Type *type) { + if (idmap_.find(type) == idmap_.end()) { + uint32_t id = head_++; + idmap_[type] = id; + return id; + } else { + return idmap_[type]; + } + } + + public: + void visit_int_type(const IntType *type) override { + result_ += fmt::format("T{} = {}int{}_t\n", get_id(type), + type->is_signed() ? "" : "u", type->num_bits()); + } + + void visit_float_type(const FloatType *type) override { + result_ += fmt::format("T{} = float{}_t\n", get_id(type), type->num_bits()); + } + + void visit_physical_pointer_type(const PhysicalPointerType *type) override { + result_ += fmt::format("T{} = T{} *\n", get_id(type), + get_id(type->get_pointed_type())); + } + + void visit_struct_type(const StructType *type) override { + result_ += fmt::format("T{} = struct {{", get_id(type)); + for (int i = 0; i < type->get_num_elements(); i++) { + result_ += fmt::format("T{}, ", get_id(type->nth_element_type(i))); + } + result_ += "}}\n"; + } + + void visit_small_vector_type(const SmallVectorType *type) override { + result_ += fmt::format("T{} = small_vector\n", get_id(type), + get_id(type->element_type()), + type->get_constant_shape()[0]); + } + + void visit_array_type(const ArrayType *type) override { + result_ += fmt::format("T{} = array\n", get_id(type), + get_id(type->element_type()), + type->get_constant_shape()[0]); + } + + static std::string print_types(const tinyir::Block *block) { + TypePrinter p; + p.visit(block); + return p.result_; + } +}; + +std::string ir_print_types(const tinyir::Block *block) { + return TypePrinter::print_types(block); +} + +class TypeReducer : public TypeVisitor { + private: + std::unique_ptr copy_{nullptr}; + std::unordered_map oldptr2newptr_; + + public: + TypeReducer() { + copy_ = std::make_unique(); + } + + const tinyir::Type *check_type(const tinyir::Type *type) { + if (oldptr2newptr_.find(type) != oldptr2newptr_.end()) { + return oldptr2newptr_[type]; + } + for (const auto &t : copy_->nodes()) { + if (t->equals(type)) { + oldptr2newptr_[type] = (const tinyir::Type *)t.get(); + return (const tinyir::Type *)t.get(); + } + } + return nullptr; + } + + void visit_int_type(const IntType *type) override { + if (!check_type(type)) { + oldptr2newptr_[type] = copy_->emplace_back(*type); + } + } + + void visit_float_type(const FloatType *type) override { + if (!check_type(type)) { + oldptr2newptr_[type] = copy_->emplace_back(*type); + } + } + + void visit_physical_pointer_type(const PhysicalPointerType *type) override { + if (!check_type(type)) { + const tinyir::Type *pointed = check_type(type->get_pointed_type()); + TI_ASSERT(pointed); + oldptr2newptr_[type] = copy_->emplace_back(pointed); + } + } + + void visit_struct_type(const StructType *type) override { + if (!check_type(type)) { + std::vector elements; + for (int i = 0; i < type->get_num_elements(); i++) { + const tinyir::Type *elm = check_type(type->nth_element_type(i)); + TI_ASSERT(elm); + elements.push_back(elm); + } + oldptr2newptr_[type] = copy_->emplace_back(elements); + } + } + + void visit_small_vector_type(const SmallVectorType *type) override { + if (!check_type(type)) { + const tinyir::Type *element = check_type(type->element_type()); + TI_ASSERT(element); + oldptr2newptr_[type] = copy_->emplace_back( + element, type->get_constant_shape()[0]); + } + } + + void visit_array_type(const ArrayType *type) override { + if (!check_type(type)) { + const tinyir::Type *element = check_type(type->element_type()); + TI_ASSERT(element); + oldptr2newptr_[type] = copy_->emplace_back( + element, type->get_constant_shape()[0]); + } + } + + static std::unique_ptr reduce(tinyir::Block *blk) { + TypeReducer reducer; + reducer.visit(blk); + return std::move(reducer.copy_); + } +}; + +std::unique_ptr ir_reduce_types(tinyir::Block *blk) { + return TypeReducer::reduce(blk); +} + +class Translate2Spirv : public TypeVisitor { + private: + IRBuilder *spir_builder_{nullptr}; + tinyir::LayoutContext &layout_context_; + + public: + std::unordered_map ir_node_2_spv_value; + + Translate2Spirv(IRBuilder *spir_builder, + tinyir::LayoutContext &layout_context) + : spir_builder_(spir_builder), layout_context_(layout_context) { + } + + void visit_int_type(const IntType *type) override { + SType vt = spir_builder_->get_null_type(); + spir_builder_->declare_global(spv::OpTypeInt, vt, type->num_bits(), + type->is_signed() ? 1 : 0); + ir_node_2_spv_value[type] = vt.id; + } + + void visit_float_type(const FloatType *type) override { + SType vt = spir_builder_->get_null_type(); + spir_builder_->declare_global(spv::OpTypeFloat, vt, type->num_bits()); + ir_node_2_spv_value[type] = vt.id; + } + + void visit_physical_pointer_type(const PhysicalPointerType *type) override { + SType vt = spir_builder_->get_null_type(); + spir_builder_->declare_global( + spv::OpTypePointer, vt, spv::StorageClassPhysicalStorageBuffer, + ir_node_2_spv_value[type->get_pointed_type()]); + ir_node_2_spv_value[type] = vt.id; + } + + void visit_struct_type(const StructType *type) override { + std::vector element_ids; + for (int i = 0; i < type->get_num_elements(); i++) { + element_ids.push_back(ir_node_2_spv_value[type->nth_element_type(i)]); + } + SType vt = spir_builder_->get_null_type(); + spir_builder_->declare_global(spv::OpTypeStruct, vt, element_ids); + ir_node_2_spv_value[type] = vt.id; + for (int i = 0; i < type->get_num_elements(); i++) { + spir_builder_->decorate(spv::OpMemberDecorate, vt, i, spv::DecorationOffset, + type->nth_element_offset(i, layout_context_)); + } + } + + void visit_small_vector_type(const SmallVectorType *type) override { + SType vt = spir_builder_->get_null_type(); + spir_builder_->declare_global(spv::OpTypeVector, vt, + ir_node_2_spv_value[type->element_type()], + type->get_constant_shape()[0]); + ir_node_2_spv_value[type] = vt.id; + } + + void visit_array_type(const ArrayType *type) override { + SType vt = spir_builder_->get_null_type(); + spir_builder_->declare_global(spv::OpTypeArray, vt, + ir_node_2_spv_value[type->element_type()], + type->get_constant_shape()[0]); + ir_node_2_spv_value[type] = vt.id; + spir_builder_->decorate(spv::OpDecorate, vt, spv::DecorationArrayStride, + type->memory_alignment_size(layout_context_)); + } +}; + +std::unordered_map ir_translate_to_spirv( + const tinyir::Block *blk, + tinyir::LayoutContext &layout_ctx, + IRBuilder *spir_builder) { + Translate2Spirv translator(spir_builder, layout_ctx); + translator.visit(blk); + return std::move(translator.ir_node_2_spv_value); +} + +} +} +} diff --git a/taichi/codegen/spirv/spirv_types.h b/taichi/codegen/spirv/spirv_types.h new file mode 100644 index 0000000000000..f0646c0ff6a02 --- /dev/null +++ b/taichi/codegen/spirv/spirv_types.h @@ -0,0 +1,240 @@ +#pragma once + +#include "lib_tiny_ir.h" +#include "taichi/ir/type.h" + +namespace taichi { +namespace lang { +namespace spirv { + +class STD140LayoutContext : public tinyir::LayoutContext {}; +class STD430LayoutContext : public tinyir::LayoutContext {}; + +class IntType : public tinyir::Type, public tinyir::MemRefElementTypeInterface { + public: + IntType(int num_bits, bool is_signed) + : num_bits_(num_bits), is_signed_(is_signed) { + } + + int num_bits() const { + return num_bits_; + } + + bool is_signed() const { + return is_signed_; + } + + size_t memory_size(tinyir::LayoutContext &ctx) const override{ + return tinyir::ceil_div(num_bits(), 8); + } + + size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override { + return tinyir::ceil_div(num_bits(), 8); + } + + private: + bool is_equal(const Polymorphic &other) const override { + const IntType &t = (const IntType &)other; + return t.num_bits_ == num_bits_ && t.is_signed_ == is_signed_; + } + + int num_bits_{0}; + bool is_signed_{false}; +}; + +class FloatType : public tinyir::Type, public tinyir::MemRefElementTypeInterface { + public: + FloatType(int num_bits) + : num_bits_(num_bits) { + } + + int num_bits() const { + return num_bits_; + } + + size_t memory_size(tinyir::LayoutContext &ctx) const override { + return tinyir::ceil_div(num_bits(), 8); + } + + size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override { + return tinyir::ceil_div(num_bits(), 8); + } + + private: + int num_bits_{0}; + + bool is_equal(const Polymorphic &other) const override { + const FloatType &t = (const FloatType &)other; + return t.num_bits_ == num_bits_; + } +}; + +class PhysicalPointerType : public IntType, + public tinyir::PointerTypeInterface { + public: + PhysicalPointerType(const tinyir::Type *pointed_type) + : IntType(/*num_bits=*/64, /*is_signed=*/false), pointed_type_(pointed_type) { + } + + const tinyir::Type *get_pointed_type() const override { + return pointed_type_; + } + + private: + const tinyir::Type *pointed_type_; + + bool is_equal(const Polymorphic &other) const override { + const PhysicalPointerType &pt = (const PhysicalPointerType &)other; + return IntType::operator==((const IntType &)other) && + pointed_type_->equals(pt.pointed_type_); + } +}; + +class StructType : public tinyir::Type, + public tinyir::AggregateTypeInterface, + public tinyir::MemRefAggregateTypeInterface { + public: + StructType(std::vector &elements) : elements_(elements) { + } + + const tinyir::Type *nth_element_type(int n) const override { + return elements_[n]; + } + + int get_num_elements() const override { + return elements_.size(); + } + + size_t memory_size(tinyir::LayoutContext &ctx) const override; + + size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override; + + size_t nth_element_offset(int n, tinyir::LayoutContext &ctx) const override; + + private: + std::vector elements_; + + bool is_equal(const Polymorphic &other) const override { + const StructType &t = (const StructType &)other; + if (t.get_num_elements() != get_num_elements()) { + return false; + } + for (int i = 0; i < get_num_elements(); i++) { + if (!elements_[i]->equals(t.elements_[i])) { + return false; + } + } + return true; + } +}; + +class SmallVectorType : public tinyir::Type, + public tinyir::ShapedTypeInterface, + public tinyir::MemRefElementTypeInterface { + public: + SmallVectorType(const tinyir::Type *element_type, int num_elements); + + const tinyir::Type *element_type() const override { + return element_type_; + } + + bool is_constant_shape() const override { + return true; + } + + std::vector get_constant_shape() const override { + return {size_t(num_elements_)}; + } + + size_t memory_size(tinyir::LayoutContext &ctx) const override; + + size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override; + + private: + bool is_equal(const Polymorphic &other) const override { + const SmallVectorType &t = (const SmallVectorType &)other; + return num_elements_ == t.num_elements_ && element_type_->equals(t.element_type_); + } + + const tinyir::Type *element_type_{nullptr}; + int num_elements_{0}; +}; + +class ArrayType : public tinyir::Type, + public tinyir::ShapedTypeInterface, + public tinyir::MemRefAggregateTypeInterface { + public: + ArrayType(const tinyir::Type *element_type, size_t size) + : element_type_(element_type), size_(size) { + } + + const tinyir::Type *element_type() const override { + return element_type_; + } + + bool is_constant_shape() const override { + return true; + } + + std::vector get_constant_shape() const override { + return {size_}; + } + + size_t memory_size(tinyir::LayoutContext &ctx) const override; + + size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override; + + size_t nth_element_offset(int n, tinyir::LayoutContext &ctx) const override; + + private: + bool is_equal(const Polymorphic &other) const override { + const ArrayType &t = (const ArrayType &)other; + return size_ == t.size_ && + element_type_->equals(t.element_type_); + } + + const tinyir::Type *element_type_{nullptr}; + size_t size_{0}; +}; + +bool bitcast_possible(tinyir::Type *a, tinyir::Type *b, bool _inverted = false); + +class TypeVisitor : public tinyir::Visitor { + public: + void visit_type(const tinyir::Type *type) override; + + virtual void visit_int_type(const IntType *type) { + } + + virtual void visit_float_type(const FloatType *type) { + } + + virtual void visit_physical_pointer_type(const PhysicalPointerType *type) { + } + + virtual void visit_struct_type(const StructType *type) { + } + + virtual void visit_small_vector_type(const SmallVectorType *type) { + } + + virtual void visit_array_type(const ArrayType *type) { + } +}; + +const tinyir::Type *translate_ti_primitive(tinyir::Block &ir_module, + const DataType t); + +std::string ir_print_types(const tinyir::Block *block); + +std::unique_ptr ir_reduce_types(tinyir::Block *blk); + +class IRBuilder; + +std::unordered_map ir_translate_to_spirv( + const tinyir::Block *blk, tinyir::LayoutContext &layout_ctx, + IRBuilder *spir_builder); + +} +} +} diff --git a/taichi/runtime/gfx/runtime.cpp b/taichi/runtime/gfx/runtime.cpp index c62903e65179b..f1f1149cca782 100644 --- a/taichi/runtime/gfx/runtime.cpp +++ b/taichi/runtime/gfx/runtime.cpp @@ -64,12 +64,15 @@ class HostDeviceContextBlitter { RuntimeContext::DevAllocType::kNone && ext_arr_size.at(i)) { // Only need to blit ext arrs (host array) - DeviceAllocation buffer = ext_arrays.at(i); - char *const device_arr_ptr = - reinterpret_cast(device_->map(buffer)); - const void *host_ptr = host_ctx_->get_arg(i); - std::memcpy(device_arr_ptr, host_ptr, ext_arr_size.at(i)); - device_->unmap(buffer); + uint32_t access = uint32_t(ctx_attribs_->arr_access.at(i)); + if (access & uint32_t(irpass::ExternalPtrAccess::READ)) { + DeviceAllocation buffer = ext_arrays.at(i); + char *const device_arr_ptr = + reinterpret_cast(device_->map(buffer)); + const void *host_ptr = host_ctx_->get_arg(i); + std::memcpy(device_arr_ptr, host_ptr, ext_arr_size.at(i)); + device_->unmap(buffer); + } } // Substitue in the device address if supported if ((host_ctx_->device_allocation_type[i] == @@ -125,7 +128,7 @@ class HostDeviceContextBlitter { bool device_to_host( CommandList *cmdlist, - const std::unordered_map &ext_arrays, + const std::unordered_map &ext_array_shadows, const std::unordered_map &ext_arr_size, const std::vector &wait_semaphore) { if (ctx_attribs_->empty()) { @@ -159,12 +162,15 @@ class HostDeviceContextBlitter { RuntimeContext::DevAllocType::kNone && ext_arr_size.at(i)) { // Only need to blit ext arrs (host array) - DeviceAllocation buffer = ext_arrays.at(i); - char *const device_arr_ptr = - reinterpret_cast(device_->map(buffer)); - void *host_ptr = host_ctx_->get_arg(i); - std::memcpy(host_ptr, device_arr_ptr, ext_arr_size.at(i)); - device_->unmap(buffer); + uint32_t access = uint32_t(ctx_attribs_->arr_access.at(i)); + if (access & uint32_t(irpass::ExternalPtrAccess::WRITE)) { + DeviceAllocation buffer = ext_array_shadows.at(i); + char *const device_arr_ptr = + reinterpret_cast(device_->map(buffer)); + void *host_ptr = host_ctx_->get_arg(i); + std::memcpy(host_ptr, device_arr_ptr, ext_arr_size.at(i)); + device_->unmap(buffer); + } } } } @@ -315,60 +321,8 @@ size_t CompiledTaichiKernel::get_ret_buffer_size() const { return ret_buffer_size_; } -void CompiledTaichiKernel::generate_command_list( - CommandList *cmdlist, - DeviceAllocationGuard *args_buffer, - DeviceAllocationGuard *ret_buffer, - const std::unordered_map &ext_arrs, - const std::unordered_map &textures) const { - const auto &task_attribs = ti_kernel_attribs_.tasks_attribs; - - for (int i = 0; i < task_attribs.size(); ++i) { - const auto &attribs = task_attribs[i]; - auto vp = pipelines_[i].get(); - const int group_x = (attribs.advisory_total_num_threads + - attribs.advisory_num_threads_per_group - 1) / - attribs.advisory_num_threads_per_group; - ResourceBinder *binder = vp->resource_binder(); - for (auto &bind : attribs.buffer_binds) { - if (bind.buffer.type == BufferType::ExtArr) { - binder->rw_buffer(0, bind.binding, ext_arrs.at(bind.buffer.root_id)); - } else if (args_buffer && bind.buffer.type == BufferType::Args) { - binder->buffer(0, bind.binding, *args_buffer); - } else if (ret_buffer && bind.buffer.type == BufferType::Rets) { - binder->rw_buffer(0, bind.binding, *ret_buffer); - } else { - DeviceAllocation *alloc = input_buffers_.at(bind.buffer); - if (alloc) { - binder->rw_buffer(0, bind.binding, *alloc); - } - } - } - - for (auto &bind : attribs.texture_binds) { - DeviceAllocation texture = textures.at(bind.arg_id); - cmdlist->image_transition(texture, ImageLayout::undefined, - ImageLayout::shader_read); - binder->image(0, bind.binding, texture, {}); - } - - if (attribs.task_type == OffloadedTaskType::listgen) { - for (auto &bind : attribs.buffer_binds) { - if (bind.buffer.type == BufferType::ListGen) { - // FIXME: properlly support multiple list - cmdlist->buffer_fill(input_buffers_.at(bind.buffer)->get_ptr(0), - kBufferSizeEntireSize, - /*data=*/0); - cmdlist->buffer_barrier(*input_buffers_.at(bind.buffer)); - } - } - } - - cmdlist->bind_pipeline(vp); - cmdlist->bind_resources(binder); - cmdlist->dispatch(group_x); - cmdlist->memory_barrier(); - } +Pipeline *CompiledTaichiKernel::get_pipeline(int i) { + return pipelines_[i].get(); } GfxRuntime::GfxRuntime(const Params ¶ms) @@ -441,6 +395,7 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) { // `any_arrays` contain both external arrays and NDArrays std::vector> allocated_buffers; std::unordered_map any_arrays; + std::unordered_map any_array_shadows; // `ext_array_size` only holds the size of external arrays (host arrays) // As buffer size information is only needed when it needs to be allocated // and transferred by the host @@ -476,13 +431,26 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) { } } else { ext_array_size[i] = host_ctx->array_runtime_sizes[i]; + uint32_t access = uint32_t(ti_kernel->ti_kernel_attribs().ctx_attribs.arr_access.at(i)); + // Alloc ext arr if (ext_array_size[i]) { + bool host_write = access & uint32_t(irpass::ExternalPtrAccess::READ); auto allocated = device_->allocate_memory_unique( - {ext_array_size[i], /*host_write=*/true, /*host_read=*/true, + {ext_array_size[i], host_write, false, /*export_sharing=*/false, AllocUsage::Storage}); any_arrays[i] = *allocated.get(); allocated_buffers.push_back(std::move(allocated)); + + bool host_read = + access & uint32_t(irpass::ExternalPtrAccess::WRITE); + if (host_read) { + auto allocated = device_->allocate_memory_unique( + {ext_array_size[i], false, true, + /*export_sharing=*/false, AllocUsage::Storage}); + any_array_shadows[i] = *allocated.get(); + allocated_buffers.push_back(std::move(allocated)); + } } else { any_arrays[i] = kDeviceNullAllocation; } @@ -502,8 +470,62 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) { } // Record commands - ti_kernel->generate_command_list(current_cmdlist_.get(), args_buffer.get(), - ret_buffer.get(), any_arrays, textures); + const auto &task_attribs = ti_kernel->ti_kernel_attribs().tasks_attribs; + + for (int i = 0; i < task_attribs.size(); ++i) { + const auto &attribs = task_attribs[i]; + auto vp = ti_kernel->get_pipeline(i); + const int group_x = (attribs.advisory_total_num_threads + + attribs.advisory_num_threads_per_group - 1) / + attribs.advisory_num_threads_per_group; + ResourceBinder *binder = vp->resource_binder(); + for (auto &bind : attribs.buffer_binds) { + if (bind.buffer.type == BufferType::ExtArr) { + binder->rw_buffer(0, bind.binding, any_arrays.at(bind.buffer.root_id)); + } else if (args_buffer && bind.buffer.type == BufferType::Args) { + binder->buffer(0, bind.binding, *args_buffer); + } else if (ret_buffer && bind.buffer.type == BufferType::Rets) { + binder->rw_buffer(0, bind.binding, *ret_buffer); + } else { + DeviceAllocation *alloc = ti_kernel->get_buffer_bind(bind.buffer); + if (alloc) { + binder->rw_buffer(0, bind.binding, *alloc); + } + } + } + + for (auto &bind : attribs.texture_binds) { + DeviceAllocation texture = textures.at(bind.arg_id); + current_cmdlist_->image_transition(texture, ImageLayout::undefined, + ImageLayout::shader_read); + binder->image(0, bind.binding, texture, {}); + } + + if (attribs.task_type == OffloadedTaskType::listgen) { + for (auto &bind : attribs.buffer_binds) { + if (bind.buffer.type == BufferType::ListGen) { + // FIXME: properlly support multiple list + current_cmdlist_->buffer_fill( + ti_kernel->get_buffer_bind(bind.buffer)->get_ptr(0), + kBufferSizeEntireSize, + /*data=*/0); + current_cmdlist_->buffer_barrier( + *ti_kernel->get_buffer_bind(bind.buffer)); + } + } + } + + current_cmdlist_->bind_pipeline(vp); + current_cmdlist_->bind_resources(binder); + current_cmdlist_->dispatch(group_x); + current_cmdlist_->memory_barrier(); + } + + for (auto &[id, shadow] : any_array_shadows) { + current_cmdlist_->buffer_copy(shadow.get_ptr(0), + any_arrays.at(id).get_ptr(0), + ext_array_size.at(id)); + } // Keep context buffers used in this dispatch if (ti_kernel->get_args_buffer_size()) { @@ -517,7 +539,7 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) { std::vector wait_semaphore; if (ctx_blitter) { - if (ctx_blitter->device_to_host(current_cmdlist_.get(), any_arrays, + if (ctx_blitter->device_to_host(current_cmdlist_.get(), any_array_shadows, ext_array_size, wait_semaphore)) { current_cmdlist_ = nullptr; ctx_buffers_.clear(); diff --git a/taichi/runtime/gfx/runtime.h b/taichi/runtime/gfx/runtime.h index 6c7cbc8a50563..844d5084e6472 100644 --- a/taichi/runtime/gfx/runtime.h +++ b/taichi/runtime/gfx/runtime.h @@ -54,12 +54,11 @@ class CompiledTaichiKernel { size_t get_args_buffer_size() const; size_t get_ret_buffer_size() const; - void generate_command_list( - CommandList *cmdlist, - DeviceAllocationGuard *args_buffer, - DeviceAllocationGuard *ret_buffer, - const std::unordered_map &ext_arrs, - const std::unordered_map &textures) const; + Pipeline *get_pipeline(int i); + + DeviceAllocation *get_buffer_bind(const BufferInfo &bind) { + return input_buffers_[bind]; + } private: TaichiKernelAttributes ti_kernel_attribs_; From a3129625817fbde3f6c62001e865553c9adbe81d Mon Sep 17 00:00:00 2001 From: Cheng Cao Date: Mon, 20 Jun 2022 14:06:17 -0700 Subject: [PATCH 2/9] Use type reducer --- taichi/codegen/spirv/spirv_codegen.cpp | 35 +++++------------- taichi/codegen/spirv/spirv_types.cpp | 49 +++++++++++++------------- taichi/codegen/spirv/spirv_types.h | 4 ++- 3 files changed, 36 insertions(+), 52 deletions(-) diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 7014719077b4b..75f1a7840d3aa 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -1970,32 +1970,7 @@ class TaskCodegen : public IRVisitor { if (!ctx_attribs_->has_args()) return; - /* - std::vector> - struct_components_; - for (auto &arg : ctx_attribs_->args()) { - if (arg.is_array && - device_->get_cap( - DeviceCapability::spirv_has_physical_storage_buffer)) { - struct_components_.emplace_back(ir_->u64_type(), - "arg_ptr" + std::to_string(arg.index), - arg.offset_in_mem); - } else { - struct_components_.emplace_back( - ir_->get_primitive_type(PrimitiveType::get(arg.dtype)), - "arg" + std::to_string(arg.index), arg.offset_in_mem); - } - } - // A compromise for use in constants buffer - // where scalar arrays follow very weird packing rules - for (int i = 0; i < ctx_attribs_->extra_args_bytes() / 4; i++) { - struct_components_.emplace_back( - ir_->i32_type(), "extra_args" + std::to_string(i), - ctx_attribs_->extra_args_mem_offset() + i * 4); - } - args_struct_type_ = ir_->create_struct_type(struct_components_); - */ - + // Generate struct IR tinyir::Block blk; std::vector element_types; for (auto &arg : ctx_attribs_->args()) { @@ -2016,8 +1991,14 @@ class TaskCodegen : public IRVisitor { } const tinyir::Type *struct_type = blk.emplace_back(element_types); + // Reduce struct IR + std::unordered_map old2new; + auto reduced_blk = ir_reduce_types(&blk, old2new); + struct_type = old2new[struct_type]; + + // Layout & translate to SPIR-V STD140LayoutContext layout_ctx; - auto map = ir_translate_to_spirv(&blk, layout_ctx, ir_.get()); + auto map = ir_translate_to_spirv(reduced_blk.get(), layout_ctx, ir_.get()); args_struct_type_.id = map[struct_type]; args_buffer_value_ = diff --git a/taichi/codegen/spirv/spirv_types.cpp b/taichi/codegen/spirv/spirv_types.cpp index 051315e417e2d..000680daaa863 100644 --- a/taichi/codegen/spirv/spirv_types.cpp +++ b/taichi/codegen/spirv/spirv_types.cpp @@ -284,22 +284,25 @@ std::string ir_print_types(const tinyir::Block *block) { } class TypeReducer : public TypeVisitor { - private: - std::unique_ptr copy_{nullptr}; - std::unordered_map oldptr2newptr_; - public: - TypeReducer() { - copy_ = std::make_unique(); + std::unique_ptr copy{nullptr}; + std::unordered_map + &oldptr2newptr; + + TypeReducer( + std::unordered_map &old2new) + : oldptr2newptr(old2new) { + copy = std::make_unique(); + old2new.clear(); } const tinyir::Type *check_type(const tinyir::Type *type) { - if (oldptr2newptr_.find(type) != oldptr2newptr_.end()) { - return oldptr2newptr_[type]; + if (oldptr2newptr.find(type) != oldptr2newptr.end()) { + return oldptr2newptr[type]; } - for (const auto &t : copy_->nodes()) { + for (const auto &t : copy->nodes()) { if (t->equals(type)) { - oldptr2newptr_[type] = (const tinyir::Type *)t.get(); + oldptr2newptr[type] = (const tinyir::Type *)t.get(); return (const tinyir::Type *)t.get(); } } @@ -308,13 +311,13 @@ class TypeReducer : public TypeVisitor { void visit_int_type(const IntType *type) override { if (!check_type(type)) { - oldptr2newptr_[type] = copy_->emplace_back(*type); + oldptr2newptr[type] = copy->emplace_back(*type); } } void visit_float_type(const FloatType *type) override { if (!check_type(type)) { - oldptr2newptr_[type] = copy_->emplace_back(*type); + oldptr2newptr[type] = copy->emplace_back(*type); } } @@ -322,7 +325,7 @@ class TypeReducer : public TypeVisitor { if (!check_type(type)) { const tinyir::Type *pointed = check_type(type->get_pointed_type()); TI_ASSERT(pointed); - oldptr2newptr_[type] = copy_->emplace_back(pointed); + oldptr2newptr[type] = copy->emplace_back(pointed); } } @@ -334,7 +337,7 @@ class TypeReducer : public TypeVisitor { TI_ASSERT(elm); elements.push_back(elm); } - oldptr2newptr_[type] = copy_->emplace_back(elements); + oldptr2newptr[type] = copy->emplace_back(elements); } } @@ -342,7 +345,7 @@ class TypeReducer : public TypeVisitor { if (!check_type(type)) { const tinyir::Type *element = check_type(type->element_type()); TI_ASSERT(element); - oldptr2newptr_[type] = copy_->emplace_back( + oldptr2newptr[type] = copy->emplace_back( element, type->get_constant_shape()[0]); } } @@ -351,20 +354,18 @@ class TypeReducer : public TypeVisitor { if (!check_type(type)) { const tinyir::Type *element = check_type(type->element_type()); TI_ASSERT(element); - oldptr2newptr_[type] = copy_->emplace_back( + oldptr2newptr[type] = copy->emplace_back( element, type->get_constant_shape()[0]); } } - - static std::unique_ptr reduce(tinyir::Block *blk) { - TypeReducer reducer; - reducer.visit(blk); - return std::move(reducer.copy_); - } }; -std::unique_ptr ir_reduce_types(tinyir::Block *blk) { - return TypeReducer::reduce(blk); +std::unique_ptr ir_reduce_types( + tinyir::Block *blk, + std::unordered_map &old2new) { + TypeReducer reducer(old2new); + reducer.visit(blk); + return std::move(reducer.copy); } class Translate2Spirv : public TypeVisitor { diff --git a/taichi/codegen/spirv/spirv_types.h b/taichi/codegen/spirv/spirv_types.h index f0646c0ff6a02..1675f432c6e38 100644 --- a/taichi/codegen/spirv/spirv_types.h +++ b/taichi/codegen/spirv/spirv_types.h @@ -227,7 +227,9 @@ const tinyir::Type *translate_ti_primitive(tinyir::Block &ir_module, std::string ir_print_types(const tinyir::Block *block); -std::unique_ptr ir_reduce_types(tinyir::Block *blk); +std::unique_ptr ir_reduce_types( + tinyir::Block *blk, + std::unordered_map &old2new); class IRBuilder; From 6bca259300359863e5aa7a33713585cc714fe953 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Jun 2022 21:59:46 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/codegen/spirv/lib_tiny_ir.h | 16 +++--- .../codegen/spirv/snode_struct_compiler.cpp | 4 +- taichi/codegen/spirv/spirv_codegen.cpp | 3 +- taichi/codegen/spirv/spirv_types.cpp | 51 ++++++++++--------- taichi/codegen/spirv/spirv_types.h | 31 ++++++----- taichi/runtime/gfx/runtime.cpp | 19 +++---- 6 files changed, 66 insertions(+), 58 deletions(-) diff --git a/taichi/codegen/spirv/lib_tiny_ir.h b/taichi/codegen/spirv/lib_tiny_ir.h index 4d5e39d029542..e5d9663ff36ab 100644 --- a/taichi/codegen/spirv/lib_tiny_ir.h +++ b/taichi/codegen/spirv/lib_tiny_ir.h @@ -71,7 +71,7 @@ class Polymorphic { class Node : public Polymorphic { public: using NodeRefs = const std::vector; - + Node() { } @@ -102,7 +102,7 @@ class Node : public Polymorphic { return false; } - private: + private: virtual bool is_equal(const Polymorphic &other) const { return false; } @@ -147,7 +147,8 @@ class LayoutContext : public Polymorphic { } void register_elem_offset(const MemRefAggregateTypeInterface *t, - int n, size_t offset) { + int n, + size_t offset) { TI_ASSERT(elem_offset_cache_.find(t) != elem_offset_cache_.end()); elem_offset_cache_[t][n] = offset; } @@ -186,7 +187,7 @@ class LayoutContext : public Polymorphic { class MemRefElementTypeInterface { public: - virtual size_t memory_size(LayoutContext& ctx) const = 0; + virtual size_t memory_size(LayoutContext &ctx) const = 0; virtual size_t memory_alignment_size(LayoutContext &ctx) const = 0; }; @@ -215,7 +216,7 @@ class PointerTypeInterface { class Block { public: - template + template T *emplace_back(E... args) { nodes_.push_back(std::make_unique(args...)); return static_cast(nodes_.back().get()); @@ -257,6 +258,5 @@ class Visitor { } }; -} -} - +} // namespace tinyir +} // namespace taichi diff --git a/taichi/codegen/spirv/snode_struct_compiler.cpp b/taichi/codegen/spirv/snode_struct_compiler.cpp index 330b5fcc40c50..a3947bab9120e 100644 --- a/taichi/codegen/spirv/snode_struct_compiler.cpp +++ b/taichi/codegen/spirv/snode_struct_compiler.cpp @@ -32,7 +32,6 @@ class StructCompiler { } private: - const tinyir::Type *construct(tinyir::Block &ir_module, SNode *sn) { const tinyir::Type *cell_type = nullptr; @@ -47,7 +46,8 @@ class StructCompiler { struct_elements.push_back(elem_type); } tinyir::Type *st = ir_module.emplace_back(struct_elements); - st->set_debug_name(fmt::format("{}_{}", snode_type_name(sn->type), sn->get_name())); + st->set_debug_name( + fmt::format("{}_{}", snode_type_name(sn->type), sn->get_name())); cell_type = st; if (sn->type == SNodeType::pointer) { diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 75f1a7840d3aa..8917434992f94 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -1989,7 +1989,8 @@ class TaskCodegen : public IRVisitor { for (int i = 0; i < ctx_attribs_->extra_args_bytes() / 4; i++) { element_types.push_back(i32_type); } - const tinyir::Type *struct_type = blk.emplace_back(element_types); + const tinyir::Type *struct_type = + blk.emplace_back(element_types); // Reduce struct IR std::unordered_map old2new; diff --git a/taichi/codegen/spirv/spirv_types.cpp b/taichi/codegen/spirv/spirv_types.cpp index 000680daaa863..2f55aaa0bf76b 100644 --- a/taichi/codegen/spirv/spirv_types.cpp +++ b/taichi/codegen/spirv/spirv_types.cpp @@ -28,7 +28,8 @@ size_t StructType::memory_size(tinyir::LayoutContext &ctx) const { } if (ctx.is()) { - // With STD140 layout, the next member is rounded up to the alignment size. Thus we should simply size up the struct to the alignment. + // With STD140 layout, the next member is rounded up to the alignment size. + // Thus we should simply size up the struct to the alignment. size_t self_alignment = this->memory_alignment_size(ctx); size_head = tinyir::ceil_div(size_head, self_alignment) * self_alignment; } @@ -37,8 +38,7 @@ size_t StructType::memory_size(tinyir::LayoutContext &ctx) const { return size_head; } -size_t StructType::memory_alignment_size( - tinyir::LayoutContext &ctx) const { +size_t StructType::memory_alignment_size(tinyir::LayoutContext &ctx) const { if (size_t s = ctx.query_alignment(this)) { return s; } @@ -46,7 +46,9 @@ size_t StructType::memory_alignment_size( size_t max_align = 0; for (const Type *elem : elements_) { TI_ASSERT(elem->is()); - max_align = std::max(max_align, elem->cast()->memory_alignment_size(ctx)); + max_align = std::max( + max_align, + elem->cast()->memory_alignment_size(ctx)); } if (ctx.is()) { @@ -70,20 +72,22 @@ SmallVectorType::SmallVectorType(const Type *element_type, int num_elements) TI_ASSERT(num_elements > 1 && num_elements_ <= 4); } - size_t SmallVectorType::memory_size(tinyir::LayoutContext &ctx) const { if (size_t s = ctx.query_size(this)) { return s; } - size_t size = element_type_->cast() - ->memory_size(ctx) * num_elements_; + size_t size = + element_type_->cast()->memory_size( + ctx) * + num_elements_; ctx.register_size(this, size); return size; } -size_t SmallVectorType::memory_alignment_size(tinyir::LayoutContext &ctx) const { +size_t SmallVectorType::memory_alignment_size( + tinyir::LayoutContext &ctx) const { if (size_t s = ctx.query_alignment(this)) { return s; } @@ -94,7 +98,8 @@ size_t SmallVectorType::memory_alignment_size(tinyir::LayoutContext &ctx) const if (ctx.is() || ctx.is()) { // For STD140 / STD430, small vectors are Power-of-Two aligned - // In C or "Scalar block layout", blocks are aligned to its compoment alignment + // In C or "Scalar block layout", blocks are aligned to its compoment + // alignment if (num_elements_ == 2) { align *= 2; } else { @@ -111,12 +116,12 @@ size_t ArrayType::memory_size(tinyir::LayoutContext &ctx) const { return s; } - size_t elem_align = - element_type_->cast()->memory_alignment_size( - ctx); + size_t elem_align = element_type_->cast() + ->memory_alignment_size(ctx); if (ctx.is()) { - // For STD140, arrays element stride equals the base alignment of the array itself + // For STD140, arrays element stride equals the base alignment of the array + // itself elem_align = this->memory_alignment_size(ctx); } size_t size = elem_align * size_; @@ -149,9 +154,7 @@ size_t ArrayType::nth_element_offset(int n, tinyir::LayoutContext &ctx) const { return elem_align * n; } -bool bitcast_possible(tinyir::Type *a, - tinyir::Type *b, - bool _inverted) { +bool bitcast_possible(tinyir::Type *a, tinyir::Type *b, bool _inverted) { if (a->is() && b->is()) { return a->as()->num_bits() == b->as()->num_bits(); } else if (a->is() && b->is()) { @@ -286,8 +289,7 @@ std::string ir_print_types(const tinyir::Block *block) { class TypeReducer : public TypeVisitor { public: std::unique_ptr copy{nullptr}; - std::unordered_map - &oldptr2newptr; + std::unordered_map &oldptr2newptr; TypeReducer( std::unordered_map &old2new) @@ -354,8 +356,8 @@ class TypeReducer : public TypeVisitor { if (!check_type(type)) { const tinyir::Type *element = check_type(type->element_type()); TI_ASSERT(element); - oldptr2newptr[type] = copy->emplace_back( - element, type->get_constant_shape()[0]); + oldptr2newptr[type] = + copy->emplace_back(element, type->get_constant_shape()[0]); } } }; @@ -411,7 +413,8 @@ class Translate2Spirv : public TypeVisitor { spir_builder_->declare_global(spv::OpTypeStruct, vt, element_ids); ir_node_2_spv_value[type] = vt.id; for (int i = 0; i < type->get_num_elements(); i++) { - spir_builder_->decorate(spv::OpMemberDecorate, vt, i, spv::DecorationOffset, + spir_builder_->decorate(spv::OpMemberDecorate, vt, i, + spv::DecorationOffset, type->nth_element_offset(i, layout_context_)); } } @@ -444,6 +447,6 @@ std::unordered_map ir_translate_to_spirv( return std::move(translator.ir_node_2_spv_value); } -} -} -} +} // namespace spirv +} // namespace lang +} // namespace taichi diff --git a/taichi/codegen/spirv/spirv_types.h b/taichi/codegen/spirv/spirv_types.h index 1675f432c6e38..8870d697f6a55 100644 --- a/taichi/codegen/spirv/spirv_types.h +++ b/taichi/codegen/spirv/spirv_types.h @@ -24,10 +24,10 @@ class IntType : public tinyir::Type, public tinyir::MemRefElementTypeInterface { return is_signed_; } - size_t memory_size(tinyir::LayoutContext &ctx) const override{ + size_t memory_size(tinyir::LayoutContext &ctx) const override { return tinyir::ceil_div(num_bits(), 8); } - + size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override { return tinyir::ceil_div(num_bits(), 8); } @@ -42,10 +42,10 @@ class IntType : public tinyir::Type, public tinyir::MemRefElementTypeInterface { bool is_signed_{false}; }; -class FloatType : public tinyir::Type, public tinyir::MemRefElementTypeInterface { +class FloatType : public tinyir::Type, + public tinyir::MemRefElementTypeInterface { public: - FloatType(int num_bits) - : num_bits_(num_bits) { + FloatType(int num_bits) : num_bits_(num_bits) { } int num_bits() const { @@ -73,7 +73,8 @@ class PhysicalPointerType : public IntType, public tinyir::PointerTypeInterface { public: PhysicalPointerType(const tinyir::Type *pointed_type) - : IntType(/*num_bits=*/64, /*is_signed=*/false), pointed_type_(pointed_type) { + : IntType(/*num_bits=*/64, /*is_signed=*/false), + pointed_type_(pointed_type) { } const tinyir::Type *get_pointed_type() const override { @@ -94,7 +95,8 @@ class StructType : public tinyir::Type, public tinyir::AggregateTypeInterface, public tinyir::MemRefAggregateTypeInterface { public: - StructType(std::vector &elements) : elements_(elements) { + StructType(std::vector &elements) + : elements_(elements) { } const tinyir::Type *nth_element_type(int n) const override { @@ -153,7 +155,8 @@ class SmallVectorType : public tinyir::Type, private: bool is_equal(const Polymorphic &other) const override { const SmallVectorType &t = (const SmallVectorType &)other; - return num_elements_ == t.num_elements_ && element_type_->equals(t.element_type_); + return num_elements_ == t.num_elements_ && + element_type_->equals(t.element_type_); } const tinyir::Type *element_type_{nullptr}; @@ -189,8 +192,7 @@ class ArrayType : public tinyir::Type, private: bool is_equal(const Polymorphic &other) const override { const ArrayType &t = (const ArrayType &)other; - return size_ == t.size_ && - element_type_->equals(t.element_type_); + return size_ == t.size_ && element_type_->equals(t.element_type_); } const tinyir::Type *element_type_{nullptr}; @@ -234,9 +236,10 @@ std::unique_ptr ir_reduce_types( class IRBuilder; std::unordered_map ir_translate_to_spirv( - const tinyir::Block *blk, tinyir::LayoutContext &layout_ctx, + const tinyir::Block *blk, + tinyir::LayoutContext &layout_ctx, IRBuilder *spir_builder); -} -} -} +} // namespace spirv +} // namespace lang +} // namespace taichi diff --git a/taichi/runtime/gfx/runtime.cpp b/taichi/runtime/gfx/runtime.cpp index f1f1149cca782..460b4d931e131 100644 --- a/taichi/runtime/gfx/runtime.cpp +++ b/taichi/runtime/gfx/runtime.cpp @@ -431,11 +431,13 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) { } } else { ext_array_size[i] = host_ctx->array_runtime_sizes[i]; - uint32_t access = uint32_t(ti_kernel->ti_kernel_attribs().ctx_attribs.arr_access.at(i)); - + uint32_t access = uint32_t( + ti_kernel->ti_kernel_attribs().ctx_attribs.arr_access.at(i)); + // Alloc ext arr if (ext_array_size[i]) { - bool host_write = access & uint32_t(irpass::ExternalPtrAccess::READ); + bool host_write = + access & uint32_t(irpass::ExternalPtrAccess::READ); auto allocated = device_->allocate_memory_unique( {ext_array_size[i], host_write, false, /*export_sharing=*/false, AllocUsage::Storage}); @@ -497,7 +499,7 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) { for (auto &bind : attribs.texture_binds) { DeviceAllocation texture = textures.at(bind.arg_id); current_cmdlist_->image_transition(texture, ImageLayout::undefined, - ImageLayout::shader_read); + ImageLayout::shader_read); binder->image(0, bind.binding, texture, {}); } @@ -507,8 +509,8 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) { // FIXME: properlly support multiple list current_cmdlist_->buffer_fill( ti_kernel->get_buffer_bind(bind.buffer)->get_ptr(0), - kBufferSizeEntireSize, - /*data=*/0); + kBufferSizeEntireSize, + /*data=*/0); current_cmdlist_->buffer_barrier( *ti_kernel->get_buffer_bind(bind.buffer)); } @@ -522,9 +524,8 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) { } for (auto &[id, shadow] : any_array_shadows) { - current_cmdlist_->buffer_copy(shadow.get_ptr(0), - any_arrays.at(id).get_ptr(0), - ext_array_size.at(id)); + current_cmdlist_->buffer_copy( + shadow.get_ptr(0), any_arrays.at(id).get_ptr(0), ext_array_size.at(id)); } // Keep context buffers used in this dispatch From f74a5b2f80e46e3db70eb7cd917d89d258985bd4 Mon Sep 17 00:00:00 2001 From: Cheng Cao Date: Tue, 21 Jun 2022 09:38:57 -0700 Subject: [PATCH 4/9] Also enable host read --- tests/cpp/aot/aot_save_load_test.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/cpp/aot/aot_save_load_test.cpp b/tests/cpp/aot/aot_save_load_test.cpp index 78d8d16eb0939..5a5998c0a3758 100644 --- a/tests/cpp/aot/aot_save_load_test.cpp +++ b/tests/cpp/aot/aot_save_load_test.cpp @@ -268,6 +268,7 @@ TEST(AotSaveLoad, VulkanNdarray) { const int size = 10; taichi::lang::Device::AllocParams alloc_params; alloc_params.host_write = true; + alloc_params.host_read = true; alloc_params.size = size * sizeof(int); alloc_params.usage = taichi::lang::AllocUsage::Storage; DeviceAllocation devalloc_arr_ = From 9dc9a312e393acdec8513cde43566a9e95306a26 Mon Sep 17 00:00:00 2001 From: Cheng Cao Date: Tue, 21 Jun 2022 17:00:11 -0700 Subject: [PATCH 5/9] Temp fix for args --- taichi/codegen/spirv/kernel_utils.cpp | 10 +++++----- taichi/codegen/spirv/kernel_utils.h | 2 +- taichi/codegen/spirv/spirv_codegen.cpp | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/taichi/codegen/spirv/kernel_utils.cpp b/taichi/codegen/spirv/kernel_utils.cpp index 759061a6823cf..34e1c4ad888e1 100644 --- a/taichi/codegen/spirv/kernel_utils.cpp +++ b/taichi/codegen/spirv/kernel_utils.cpp @@ -48,7 +48,7 @@ std::string TaskAttributes::BufferBind::debug_string() const { TaskAttributes::buffers_name(buffer), binding); } -KernelContextAttributes::KernelContextAttributes(const Kernel &kernel) +KernelContextAttributes::KernelContextAttributes(const Kernel &kernel, Device *device) : args_bytes_(0), rets_bytes_(0), extra_args_bytes_(RuntimeContext::extra_args_size) { @@ -90,12 +90,12 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel) ret_attribs_vec_.push_back(ra); } - auto arange_args = [](auto *vec, size_t offset, bool is_ret) -> size_t { + auto arange_args = [](auto *vec, size_t offset, bool is_ret, bool has_buffer_ptr) -> size_t { size_t bytes = offset; for (int i = 0; i < vec->size(); ++i) { auto &attribs = (*vec)[i]; const size_t dt_bytes = - (attribs.is_array && !is_ret) + (attribs.is_array && !is_ret && has_buffer_ptr) ? sizeof(uint64_t) : data_type_size(PrimitiveType::get(attribs.dtype)); // Align bytes to the nearest multiple of dt_bytes @@ -111,12 +111,12 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel) }; TI_TRACE("args:"); - args_bytes_ = arange_args(&arg_attribs_vec_, 0, false); + args_bytes_ = arange_args(&arg_attribs_vec_, 0, false, device->get_cap(DeviceCapability::spirv_has_physical_storage_buffer)); // Align to extra args args_bytes_ = (args_bytes_ + 4 - 1) / 4 * 4; TI_TRACE("rets:"); - rets_bytes_ = arange_args(&ret_attribs_vec_, 0, true); + rets_bytes_ = arange_args(&ret_attribs_vec_, 0, true, false); TI_TRACE("sizes: args={} rets={}", args_bytes(), rets_bytes()); TI_ASSERT(has_rets() == (rets_bytes_ > 0)); diff --git a/taichi/codegen/spirv/kernel_utils.h b/taichi/codegen/spirv/kernel_utils.h index 75c4ada26a4f0..65b6332b3937c 100644 --- a/taichi/codegen/spirv/kernel_utils.h +++ b/taichi/codegen/spirv/kernel_utils.h @@ -173,7 +173,7 @@ class KernelContextAttributes { struct RetAttributes : public AttribsBase {}; KernelContextAttributes() = default; - explicit KernelContextAttributes(const Kernel &kernel); + explicit KernelContextAttributes(const Kernel &kernel, Device *device); /** * Whether this kernel has any argument diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 8917434992f94..6a3cc05d75c1e 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -2144,7 +2144,7 @@ static void spriv_message_consumer(spv_message_level_t level, } KernelCodegen::KernelCodegen(const Params ¶ms) - : params_(params), ctx_attribs_(*params.kernel) { + : params_(params), ctx_attribs_(*params.kernel, params.device) { spv_target_env target_env = SPV_ENV_VULKAN_1_0; uint32_t spirv_version = params.device->get_cap(DeviceCapability::spirv_version); From 62fce746ef59246130437b2ce85ae1eab964363b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Jun 2022 00:01:20 +0000 Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/codegen/spirv/kernel_utils.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/taichi/codegen/spirv/kernel_utils.cpp b/taichi/codegen/spirv/kernel_utils.cpp index 34e1c4ad888e1..b63e92b442108 100644 --- a/taichi/codegen/spirv/kernel_utils.cpp +++ b/taichi/codegen/spirv/kernel_utils.cpp @@ -48,7 +48,8 @@ std::string TaskAttributes::BufferBind::debug_string() const { TaskAttributes::buffers_name(buffer), binding); } -KernelContextAttributes::KernelContextAttributes(const Kernel &kernel, Device *device) +KernelContextAttributes::KernelContextAttributes(const Kernel &kernel, + Device *device) : args_bytes_(0), rets_bytes_(0), extra_args_bytes_(RuntimeContext::extra_args_size) { @@ -90,7 +91,8 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel, Device *d ret_attribs_vec_.push_back(ra); } - auto arange_args = [](auto *vec, size_t offset, bool is_ret, bool has_buffer_ptr) -> size_t { + auto arange_args = [](auto *vec, size_t offset, bool is_ret, + bool has_buffer_ptr) -> size_t { size_t bytes = offset; for (int i = 0; i < vec->size(); ++i) { auto &attribs = (*vec)[i]; @@ -111,7 +113,9 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel, Device *d }; TI_TRACE("args:"); - args_bytes_ = arange_args(&arg_attribs_vec_, 0, false, device->get_cap(DeviceCapability::spirv_has_physical_storage_buffer)); + args_bytes_ = arange_args( + &arg_attribs_vec_, 0, false, + device->get_cap(DeviceCapability::spirv_has_physical_storage_buffer)); // Align to extra args args_bytes_ = (args_bytes_ + 4 - 1) / 4 * 4; From 83f1585c328122884f06c5daf3618df8c4454add Mon Sep 17 00:00:00 2001 From: Cheng Cao Date: Tue, 21 Jun 2022 19:01:22 -0700 Subject: [PATCH 7/9] fix arr_access --- taichi/codegen/spirv/kernel_utils.cpp | 1 + taichi/codegen/spirv/kernel_utils.h | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/codegen/spirv/kernel_utils.cpp b/taichi/codegen/spirv/kernel_utils.cpp index b63e92b442108..ea11eba554950 100644 --- a/taichi/codegen/spirv/kernel_utils.cpp +++ b/taichi/codegen/spirv/kernel_utils.cpp @@ -53,6 +53,7 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel, : args_bytes_(0), rets_bytes_(0), extra_args_bytes_(RuntimeContext::extra_args_size) { + arr_access.resize(kernel.args.size(), irpass::ExternalPtrAccess(0)); arg_attribs_vec_.reserve(kernel.args.size()); // TODO: We should be able to limit Kernel args and rets to be primitive types // as well but let's leave that as a followup up PR. diff --git a/taichi/codegen/spirv/kernel_utils.h b/taichi/codegen/spirv/kernel_utils.h index 65b6332b3937c..f5d02a156b5ed 100644 --- a/taichi/codegen/spirv/kernel_utils.h +++ b/taichi/codegen/spirv/kernel_utils.h @@ -235,7 +235,7 @@ class KernelContextAttributes { return args_bytes(); } - std::unordered_map arr_access; + std::vector arr_access; TI_IO_DEF(arg_attribs_vec_, ret_attribs_vec_, From 855689a183b66d37effb8204032c9b35cac4c429 Mon Sep 17 00:00:00 2001 From: Cheng Cao Date: Wed, 22 Jun 2022 09:27:14 -0700 Subject: [PATCH 8/9] Fix nits --- taichi/codegen/spirv/spirv_codegen.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 6a3cc05d75c1e..5c7f2a7fa434b 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -1999,8 +1999,8 @@ class TaskCodegen : public IRVisitor { // Layout & translate to SPIR-V STD140LayoutContext layout_ctx; - auto map = ir_translate_to_spirv(reduced_blk.get(), layout_ctx, ir_.get()); - args_struct_type_.id = map[struct_type]; + auto ir2spirv_map = ir_translate_to_spirv(reduced_blk.get(), layout_ctx, ir_.get()); + args_struct_type_.id = ir2spirv_map[struct_type]; args_buffer_value_ = ir_->uniform_struct_argument(args_struct_type_, 0, 0, "args"); From 5f03f121c4d0313d02d0e024b1dcb6050823e50e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Jun 2022 16:28:24 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/codegen/spirv/spirv_codegen.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 5c7f2a7fa434b..5bcf4dfb4937a 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -1999,7 +1999,8 @@ class TaskCodegen : public IRVisitor { // Layout & translate to SPIR-V STD140LayoutContext layout_ctx; - auto ir2spirv_map = ir_translate_to_spirv(reduced_blk.get(), layout_ctx, ir_.get()); + auto ir2spirv_map = + ir_translate_to_spirv(reduced_blk.get(), layout_ctx, ir_.get()); args_struct_type_.id = ir2spirv_map[struct_type]; args_buffer_value_ =