diff --git a/taichi/backends/vulkan/vulkan_device.cpp b/taichi/backends/vulkan/vulkan_device.cpp index e71d97586da1b..7bce62cc32c3f 100644 --- a/taichi/backends/vulkan/vulkan_device.cpp +++ b/taichi/backends/vulkan/vulkan_device.cpp @@ -1310,7 +1310,6 @@ DeviceAllocation VulkanDevice::allocate_memory(const AllocParams ¶ms) { if (params.usage & AllocUsage::Index) { buffer_info.usage |= VK_BUFFER_USAGE_INDEX_BUFFER_BIT; } - buffer_info.sharingMode = VK_SHARING_MODE_CONCURRENT; uint32_t queue_family_indices[] = {compute_queue_family_index_, graphics_queue_family_index_}; @@ -1351,20 +1350,22 @@ DeviceAllocation VulkanDevice::allocate_memory(const AllocParams ¶ms) { if (params.host_read && params.host_write) { #endif //__APPLE__ // This should be the unified memory on integrated GPUs - alloc_info.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; - alloc_info.preferredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT | - VK_MEMORY_PROPERTY_HOST_CACHED_BIT; + alloc_info.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | + VK_MEMORY_PROPERTY_HOST_CACHED_BIT; + alloc_info.preferredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; #ifdef __APPLE__ // weird behavior on apple: if coherent bit is not set, then the memory // writes between map() and unmap() cannot be seen by gpu alloc_info.preferredFlags |= VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; #endif //__APPLE__ } else if (params.host_read) { - alloc_info.usage = VMA_MEMORY_USAGE_GPU_TO_CPU; + alloc_info.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; + alloc_info.preferredFlags = VK_MEMORY_PROPERTY_HOST_CACHED_BIT; } else if (params.host_write) { - alloc_info.usage = VMA_MEMORY_USAGE_CPU_TO_GPU; + alloc_info.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; + alloc_info.preferredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; } else { - alloc_info.usage = VMA_MEMORY_USAGE_GPU_ONLY; + alloc_info.requiredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; } if (get_cap(DeviceCapability::spirv_has_physical_storage_buffer)) { diff --git a/taichi/codegen/spirv/kernel_utils.cpp b/taichi/codegen/spirv/kernel_utils.cpp index 759061a6823cf..ea11eba554950 100644 --- a/taichi/codegen/spirv/kernel_utils.cpp +++ b/taichi/codegen/spirv/kernel_utils.cpp @@ -48,10 +48,12 @@ std::string TaskAttributes::BufferBind::debug_string() const { TaskAttributes::buffers_name(buffer), binding); } -KernelContextAttributes::KernelContextAttributes(const Kernel &kernel) +KernelContextAttributes::KernelContextAttributes(const Kernel &kernel, + Device *device) : args_bytes_(0), rets_bytes_(0), extra_args_bytes_(RuntimeContext::extra_args_size) { + arr_access.resize(kernel.args.size(), irpass::ExternalPtrAccess(0)); arg_attribs_vec_.reserve(kernel.args.size()); // TODO: We should be able to limit Kernel args and rets to be primitive types // as well but let's leave that as a followup up PR. @@ -90,12 +92,13 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel) ret_attribs_vec_.push_back(ra); } - auto arange_args = [](auto *vec, size_t offset, bool is_ret) -> size_t { + auto arange_args = [](auto *vec, size_t offset, bool is_ret, + bool has_buffer_ptr) -> size_t { size_t bytes = offset; for (int i = 0; i < vec->size(); ++i) { auto &attribs = (*vec)[i]; const size_t dt_bytes = - (attribs.is_array && !is_ret) + (attribs.is_array && !is_ret && has_buffer_ptr) ? 
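// [Editor's note: illustrative sketch, not part of the patch.]
// The vulkan_device.cpp hunk above drops the coarse VMA_MEMORY_USAGE_* hints
// and spells out required/preferred VkMemoryPropertyFlags instead, so that
// read-back allocations can land in HOST_CACHED memory rather than in
// uncached, write-combined memory. A minimal standalone restatement of that
// selection logic follows; choose_memory_flags() is a hypothetical helper,
// while VmaAllocationCreateInfo and the VK_MEMORY_PROPERTY_* bits are the
// real VMA/Vulkan names.
#include <vk_mem_alloc.h>

static void choose_memory_flags(bool host_read,
                                bool host_write,
                                VmaAllocationCreateInfo *info) {
  if (host_read && host_write) {
    // Bidirectional access: must be mappable and cached; ideally this is the
    // unified, device-local memory on integrated GPUs.
    info->requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT |
                          VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
    info->preferredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT;
  } else if (host_read) {
    // GPU -> CPU read-back: prefer cached host memory so the CPU-side memcpy
    // out of the mapped pointer is fast.
    info->requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT;
    info->preferredFlags = VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
  } else if (host_write) {
    // CPU -> GPU upload: mappable, ideally device-local (e.g. resizable BAR).
    info->requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT;
    info->preferredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT;
  } else {
    // Device-only storage.
    info->requiredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT;
  }
}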
sizeof(uint64_t) : data_type_size(PrimitiveType::get(attribs.dtype)); // Align bytes to the nearest multiple of dt_bytes @@ -111,12 +114,14 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel) }; TI_TRACE("args:"); - args_bytes_ = arange_args(&arg_attribs_vec_, 0, false); + args_bytes_ = arange_args( + &arg_attribs_vec_, 0, false, + device->get_cap(DeviceCapability::spirv_has_physical_storage_buffer)); // Align to extra args args_bytes_ = (args_bytes_ + 4 - 1) / 4 * 4; TI_TRACE("rets:"); - rets_bytes_ = arange_args(&ret_attribs_vec_, 0, true); + rets_bytes_ = arange_args(&ret_attribs_vec_, 0, true, false); TI_TRACE("sizes: args={} rets={}", args_bytes(), rets_bytes()); TI_ASSERT(has_rets() == (rets_bytes_ > 0)); diff --git a/taichi/codegen/spirv/kernel_utils.h b/taichi/codegen/spirv/kernel_utils.h index e4b87a8e7df1b..f5d02a156b5ed 100644 --- a/taichi/codegen/spirv/kernel_utils.h +++ b/taichi/codegen/spirv/kernel_utils.h @@ -6,6 +6,7 @@ #include "taichi/ir/offloaded_task_type.h" #include "taichi/ir/type.h" +#include "taichi/ir/transforms.h" #include "taichi/backends/device.h" namespace taichi { @@ -172,7 +173,7 @@ class KernelContextAttributes { struct RetAttributes : public AttribsBase {}; KernelContextAttributes() = default; - explicit KernelContextAttributes(const Kernel &kernel); + explicit KernelContextAttributes(const Kernel &kernel, Device *device); /** * Whether this kernel has any argument @@ -234,11 +235,14 @@ class KernelContextAttributes { return args_bytes(); } + std::vector arr_access; + TI_IO_DEF(arg_attribs_vec_, ret_attribs_vec_, args_bytes_, rets_bytes_, - extra_args_bytes_); + extra_args_bytes_, + arr_access); private: std::vector arg_attribs_vec_; diff --git a/taichi/codegen/spirv/lib_tiny_ir.h b/taichi/codegen/spirv/lib_tiny_ir.h new file mode 100644 index 0000000000000..e5d9663ff36ab --- /dev/null +++ b/taichi/codegen/spirv/lib_tiny_ir.h @@ -0,0 +1,262 @@ +#pragma once + +#include "taichi/common/core.h" + +#include +#include + +namespace taichi { +namespace tinyir { + +template +T ceil_div(T v, T div) { + return (v / div) + (v % div ? 
1 : 0); +} + +// Forward decl +class Polymorphic; +class Node; +class Type; +class LayoutContext; +class MemRefElementTypeInterface; +class MemRefAggregateTypeInterface; +class ShapedTypeInterface; +class AggregateTypeInterface; +class PointerTypeInterface; +class Block; +class Visitor; + +class Polymorphic { + public: + virtual ~Polymorphic() { + } + + template + bool is() const { + return dynamic_cast(this) != nullptr; + } + + template + T *as() { + return static_cast(this); + } + + template + const T *as() const { + return static_cast(this); + } + + template + T *cast() { + return dynamic_cast(this); + } + + template + const T *cast() const { + return dynamic_cast(this); + } + + bool operator==(const Polymorphic &other) const { + return typeid(*this) == typeid(other) && is_equal(other); + } + + const bool equals(const Polymorphic *other) const { + return (*this) == (*other); + } + + private: + virtual bool is_equal(const Polymorphic &other) const = 0; +}; + +class Node : public Polymorphic { + public: + using NodeRefs = const std::vector; + + Node() { + } + + virtual ~Node() { + } + + const std::string &debug_name() const { + return debug_name_; + } + + void set_debug_name(const std::string &s) { + debug_name_ = s; + } + + virtual NodeRefs incoming() const { + return {}; + } + + virtual NodeRefs outgoing() const { + return {}; + } + + virtual bool is_leaf() const { + return false; + } + + virtual bool is_tree_node() const { + return false; + } + + private: + virtual bool is_equal(const Polymorphic &other) const { + return false; + } + + std::string debug_name_; +}; + +class Type : public Node { + public: + Type() { + } + + private: + virtual bool is_equal(const Polymorphic &other) const { + return false; + } +}; + +// The default LayoutContext is the standard C layout +class LayoutContext : public Polymorphic { + private: + std::unordered_map size_cache_; + std::unordered_map + alignment_cache_; + std::unordered_map> + elem_offset_cache_; + + public: + void register_size(const MemRefElementTypeInterface *t, size_t size) { + TI_ASSERT(size != 0); + size_cache_[t] = size; + } + + void register_alignment(const MemRefElementTypeInterface *t, size_t size) { + TI_ASSERT(size != 0); + alignment_cache_[t] = size; + } + + void register_aggregate(const MemRefAggregateTypeInterface *t, int num_elem) { + elem_offset_cache_[t] = {}; + elem_offset_cache_[t].resize(num_elem, 0); + } + + void register_elem_offset(const MemRefAggregateTypeInterface *t, + int n, + size_t offset) { + TI_ASSERT(elem_offset_cache_.find(t) != elem_offset_cache_.end()); + elem_offset_cache_[t][n] = offset; + } + + // Size or alignment can not be zero + size_t query_size(const MemRefElementTypeInterface *t) { + if (size_cache_.find(t) != size_cache_.end()) { + return size_cache_[t]; + } else { + return 0; + } + } + + size_t query_alignment(const MemRefElementTypeInterface *t) { + if (alignment_cache_.find(t) != alignment_cache_.end()) { + return alignment_cache_[t]; + } else { + return 0; + } + } + + size_t query_elem_offset(const MemRefAggregateTypeInterface *t, int n) { + if (elem_offset_cache_.find(t) != elem_offset_cache_.end()) { + return elem_offset_cache_[t][n]; + } else { + return 0; + } + } + + private: + virtual bool is_equal(const Polymorphic &other) const { + // This is only called when `other` has the same typeid + return true; + } +}; + +class MemRefElementTypeInterface { + public: + virtual size_t memory_size(LayoutContext &ctx) const = 0; + virtual size_t memory_alignment_size(LayoutContext &ctx) const = 0; +}; 
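// [Editor's note: self-contained sketch, not part of the patch.]
// tinyir::Polymorphic above gives every IR node RTTI-style type tests
// (is<T>() / cast<T>() via dynamic_cast) plus value equality that is only
// consulted when both operands share the same dynamic type (the typeid check
// in operator==). The toy program below mirrors that pattern with made-up
// node classes to show how is(), cast() and operator== interact.
#include <cassert>
#include <typeinfo>

struct NodeBase {
  virtual ~NodeBase() = default;

  template <typename T>
  bool is() const {
    return dynamic_cast<const T *>(this) != nullptr;
  }

  template <typename T>
  const T *cast() const {
    return dynamic_cast<const T *>(this);
  }

  bool operator==(const NodeBase &other) const {
    // Payload comparison only runs for identical dynamic types.
    return typeid(*this) == typeid(other) && is_equal(other);
  }

 private:
  virtual bool is_equal(const NodeBase &other) const = 0;
};

struct IntNode : NodeBase {
  explicit IntNode(int bits) : bits(bits) {}
  int bits;

 private:
  bool is_equal(const NodeBase &other) const override {
    // Safe: operator== already checked that `other` is an IntNode.
    return bits == static_cast<const IntNode &>(other).bits;
  }
};

struct FloatNode : NodeBase {
  explicit FloatNode(int bits) : bits(bits) {}
  int bits;

 private:
  bool is_equal(const NodeBase &other) const override {
    return bits == static_cast<const FloatNode &>(other).bits;
  }
};

int main() {
  IntNode i32(32), i32_again(32);
  FloatNode f32(32);
  assert(i32.is<IntNode>() && !i32.is<FloatNode>());
  assert(i32 == i32_again);   // same type, same payload
  assert(!(i32 == f32));      // typeid differs; is_equal never runs
  assert(i32.cast<FloatNode>() == nullptr);
  return 0;
}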
+ +class MemRefAggregateTypeInterface : public MemRefElementTypeInterface { + public: + virtual size_t nth_element_offset(int n, LayoutContext &ctx) const = 0; +}; + +class AggregateTypeInterface { + public: + virtual const Type *nth_element_type(int n) const = 0; + virtual int get_num_elements() const = 0; +}; + +class ShapedTypeInterface { + public: + virtual const Type *element_type() const = 0; + virtual bool is_constant_shape() const = 0; + virtual std::vector get_constant_shape() const = 0; +}; + +class PointerTypeInterface { + public: + virtual const Type *get_pointed_type() const = 0; +}; + +class Block { + public: + template + T *emplace_back(E... args) { + nodes_.push_back(std::make_unique(args...)); + return static_cast(nodes_.back().get()); + } + + template + T *push_back(std::unique_ptr &&val) { + T *ptr = val.get(); + nodes_.push_back(std::move(val)); + return ptr; + } + + const std::vector> &nodes() const { + return nodes_; + } + + private: + std::vector> nodes_; +}; + +class Visitor { + public: + virtual ~Visitor() { + } + + virtual void visit(const Node *node) { + if (node->is()) { + visit_type(node->as()); + } + } + + virtual void visit_type(const Type *type) { + } + + virtual void visit(const Block *block) { + for (auto &n : block->nodes()) { + visit(n.get()); + } + } +}; + +} // namespace tinyir +} // namespace taichi diff --git a/taichi/codegen/spirv/snode_struct_compiler.cpp b/taichi/codegen/spirv/snode_struct_compiler.cpp index 6be96bb8c8663..a3947bab9120e 100644 --- a/taichi/codegen/spirv/snode_struct_compiler.cpp +++ b/taichi/codegen/spirv/snode_struct_compiler.cpp @@ -14,11 +14,55 @@ class StructCompiler { result.root = &root; result.root_size = compute_snode_size(&root); result.snode_descriptors = std::move(snode_descriptors_); + /* + result.type_factory = new tinyir::Block; + result.root_type = construct(*result.type_factory, &root); + */ TI_TRACE("RootBuffer size={}", result.root_size); + + /* + std::unique_ptr b = ir_reduce_types(result.type_factory); + + TI_WARN("Original types:\n{}", ir_print_types(result.type_factory)); + + TI_WARN("Reduced types:\n{}", ir_print_types(b.get())); + */ + return result; } private: + const tinyir::Type *construct(tinyir::Block &ir_module, SNode *sn) { + const tinyir::Type *cell_type = nullptr; + + if (sn->is_place()) { + // Each cell is a single Type + cell_type = translate_ti_primitive(ir_module, sn->dt); + } else { + // Each cell is a struct + std::vector struct_elements; + for (auto &ch : sn->ch) { + const tinyir::Type *elem_type = construct(ir_module, ch.get()); + struct_elements.push_back(elem_type); + } + tinyir::Type *st = ir_module.emplace_back(struct_elements); + st->set_debug_name( + fmt::format("{}_{}", snode_type_name(sn->type), sn->get_name())); + cell_type = st; + + if (sn->type == SNodeType::pointer) { + cell_type = ir_module.emplace_back(cell_type); + } + } + + if (sn->num_cells_per_container == 1 || sn->is_scalar()) { + return cell_type; + } else { + return ir_module.emplace_back(cell_type, + sn->num_cells_per_container); + } + } + std::size_t compute_snode_size(SNode *sn) { const bool is_place = sn->is_place(); diff --git a/taichi/codegen/spirv/snode_struct_compiler.h b/taichi/codegen/spirv/snode_struct_compiler.h index f2b4cba6cbe5d..2109ff5af7052 100644 --- a/taichi/codegen/spirv/snode_struct_compiler.h +++ b/taichi/codegen/spirv/snode_struct_compiler.h @@ -5,6 +5,8 @@ #include "taichi/ir/snode.h" +#include "spirv_types.h" + namespace taichi { namespace lang { namespace spirv { @@ -49,6 +51,10 @@ struct 
CompiledSNodeStructs { const SNode *root{nullptr}; // Map from SNode ID to its descriptor. SNodeDescriptorsMap snode_descriptors; + + // TODO: Use the new type compiler + // tinyir::Block *type_factory; + // const tinyir::Type *root_type; }; CompiledSNodeStructs compile_snode_structs(SNode &root); diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 6cd92965fd5d4..5bcf4dfb4937a 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -99,6 +99,7 @@ class TaskCodegen : public IRVisitor { struct Result { std::vector spirv_code; TaskAttributes task_attribs; + std::unordered_map arr_access; }; Result run() { @@ -129,6 +130,7 @@ class TaskCodegen : public IRVisitor { Result res; res.spirv_code = ir_->finalize(); res.task_attribs = std::move(task_attribs_); + res.arr_access = irpass::detect_external_ptr_access_in_task(task_ir_); return res; } @@ -1968,29 +1970,38 @@ class TaskCodegen : public IRVisitor { if (!ctx_attribs_->has_args()) return; - std::vector> - struct_components_; + // Generate struct IR + tinyir::Block blk; + std::vector element_types; for (auto &arg : ctx_attribs_->args()) { + const tinyir::Type *t; if (arg.is_array && device_->get_cap( DeviceCapability::spirv_has_physical_storage_buffer)) { - struct_components_.emplace_back(ir_->u64_type(), - "arg_ptr" + std::to_string(arg.index), - arg.offset_in_mem); + t = blk.emplace_back(/*num_bits=*/64, /*is_signed=*/false); } else { - struct_components_.emplace_back( - ir_->get_primitive_type(PrimitiveType::get(arg.dtype)), - "arg" + std::to_string(arg.index), arg.offset_in_mem); + t = translate_ti_primitive(blk, PrimitiveType::get(arg.dtype)); } + element_types.push_back(t); } - // A compromise for use in constants buffer - // where scalar arrays follow very weird packing rules + const tinyir::Type *i32_type = + blk.emplace_back(/*num_bits=*/32, /*is_signed=*/true); for (int i = 0; i < ctx_attribs_->extra_args_bytes() / 4; i++) { - struct_components_.emplace_back( - ir_->i32_type(), "extra_args" + std::to_string(i), - ctx_attribs_->extra_args_mem_offset() + i * 4); + element_types.push_back(i32_type); } - args_struct_type_ = ir_->create_struct_type(struct_components_); + const tinyir::Type *struct_type = + blk.emplace_back(element_types); + + // Reduce struct IR + std::unordered_map old2new; + auto reduced_blk = ir_reduce_types(&blk, old2new); + struct_type = old2new[struct_type]; + + // Layout & translate to SPIR-V + STD140LayoutContext layout_ctx; + auto ir2spirv_map = + ir_translate_to_spirv(reduced_blk.get(), layout_ctx, ir_.get()); + args_struct_type_.id = ir2spirv_map[struct_type]; args_buffer_value_ = ir_->uniform_struct_argument(args_struct_type_, 0, 0, "args"); @@ -2134,7 +2145,7 @@ static void spriv_message_consumer(spv_message_level_t level, } KernelCodegen::KernelCodegen(const Params ¶ms) - : params_(params), ctx_attribs_(*params.kernel) { + : params_(params), ctx_attribs_(*params.kernel, params.device) { spv_target_env target_env = SPV_ENV_VULKAN_1_0; uint32_t spirv_version = params.device->get_cap(DeviceCapability::spirv_version); @@ -2199,6 +2210,10 @@ void KernelCodegen::run(TaichiKernelAttributes &kernel_attribs, TaskCodegen cgen(tp); auto task_res = cgen.run(); + for (auto &[id, access] : task_res.arr_access) { + ctx_attribs_.arr_access[id] = ctx_attribs_.arr_access[id] | access; + } + std::vector optimized_spv(task_res.spirv_code); size_t last_size; diff --git a/taichi/codegen/spirv/spirv_ir_builder.h 
b/taichi/codegen/spirv/spirv_ir_builder.h index 04d4ce00dc95e..943657b0ed5a7 100644 --- a/taichi/codegen/spirv/spirv_ir_builder.h +++ b/taichi/codegen/spirv/spirv_ir_builder.h @@ -454,6 +454,14 @@ class IRBuilder { Value query_value(std::string name) const; // Check whether a value has been evaluated bool check_value_existence(const std::string &name) const; + // Create a new SSA value + Value new_value(const SType &type, ValueKind flag) { + Value val; + val.id = id_counter_++; + val.stype = type; + val.flag = flag; + return val; + } // Support easy access to trivial data types SType i64_type() const { @@ -508,14 +516,6 @@ class IRBuilder { Value rand_i32(Value global_tmp_); private: - Value new_value(const SType &type, ValueKind flag) { - Value val; - val.id = id_counter_++; - val.stype = type; - val.flag = flag; - return val; - } - Value get_const(const SType &dtype, const uint64_t *pvalue, bool cache); SType declare_primitive_type(DataType dt); diff --git a/taichi/codegen/spirv/spirv_types.cpp b/taichi/codegen/spirv/spirv_types.cpp new file mode 100644 index 0000000000000..2f55aaa0bf76b --- /dev/null +++ b/taichi/codegen/spirv/spirv_types.cpp @@ -0,0 +1,452 @@ +#include "spirv_types.h" +#include "spirv_ir_builder.h" + +namespace taichi { +namespace lang { +namespace spirv { + +size_t StructType::memory_size(tinyir::LayoutContext &ctx) const { + if (size_t s = ctx.query_size(this)) { + return s; + } + + ctx.register_aggregate(this, elements_.size()); + + size_t size_head = 0; + int n = 0; + for (const Type *elem : elements_) { + TI_ASSERT(elem->is()); + const MemRefElementTypeInterface *mem_ref_type = + elem->cast(); + size_t elem_size = mem_ref_type->memory_size(ctx); + size_t elem_align = mem_ref_type->memory_alignment_size(ctx); + // First align the head ptr, then add the size + size_head = tinyir::ceil_div(size_head, elem_align) * elem_align; + ctx.register_elem_offset(this, n, size_head); + size_head += elem_size; + n++; + } + + if (ctx.is()) { + // With STD140 layout, the next member is rounded up to the alignment size. + // Thus we should simply size up the struct to the alignment. 
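// [Editor's note: worked example, not part of the patch.]
// Concretely, under the STD140 rules implemented here, a struct with members
// {float, float} has a natural size of 8 bytes and a max member alignment of
// 4 bytes. memory_alignment_size() rounds the struct alignment up to
// sizeof(vec4) = 16, and the rounding below then pads the size from 8 up to
//   size_head = ceil_div(8, 16) * 16 = 16,
// so consecutive cells of such a struct get a 16-byte stride, matching how
// GLSL std140 lays out uniform blocks.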
+ size_t self_alignment = this->memory_alignment_size(ctx); + size_head = tinyir::ceil_div(size_head, self_alignment) * self_alignment; + } + + ctx.register_size(this, size_head); + return size_head; +} + +size_t StructType::memory_alignment_size(tinyir::LayoutContext &ctx) const { + if (size_t s = ctx.query_alignment(this)) { + return s; + } + + size_t max_align = 0; + for (const Type *elem : elements_) { + TI_ASSERT(elem->is()); + max_align = std::max( + max_align, + elem->cast()->memory_alignment_size(ctx)); + } + + if (ctx.is()) { + // With STD140 layout, struct alignment is rounded up to `sizeof(vec4)` + constexpr size_t vec4_size = sizeof(float) * 4; + max_align = tinyir::ceil_div(max_align, vec4_size) * vec4_size; + } + + ctx.register_alignment(this, max_align); + return max_align; +} + +size_t StructType::nth_element_offset(int n, tinyir::LayoutContext &ctx) const { + this->memory_size(ctx); + + return ctx.query_elem_offset(this, n); +} + +SmallVectorType::SmallVectorType(const Type *element_type, int num_elements) + : element_type_(element_type), num_elements_(num_elements) { + TI_ASSERT(num_elements > 1 && num_elements_ <= 4); +} + +size_t SmallVectorType::memory_size(tinyir::LayoutContext &ctx) const { + if (size_t s = ctx.query_size(this)) { + return s; + } + + size_t size = + element_type_->cast()->memory_size( + ctx) * + num_elements_; + + ctx.register_size(this, size); + return size; +} + +size_t SmallVectorType::memory_alignment_size( + tinyir::LayoutContext &ctx) const { + if (size_t s = ctx.query_alignment(this)) { + return s; + } + + size_t align = + element_type_->cast()->memory_size( + ctx); + + if (ctx.is() || ctx.is()) { + // For STD140 / STD430, small vectors are Power-of-Two aligned + // In C or "Scalar block layout", blocks are aligned to its compoment + // alignment + if (num_elements_ == 2) { + align *= 2; + } else { + align *= 4; + } + } + + ctx.register_alignment(this, align); + return align; +} + +size_t ArrayType::memory_size(tinyir::LayoutContext &ctx) const { + if (size_t s = ctx.query_size(this)) { + return s; + } + + size_t elem_align = element_type_->cast() + ->memory_alignment_size(ctx); + + if (ctx.is()) { + // For STD140, arrays element stride equals the base alignment of the array + // itself + elem_align = this->memory_alignment_size(ctx); + } + size_t size = elem_align * size_; + + ctx.register_size(this, size); + return size; +} + +size_t ArrayType::memory_alignment_size(tinyir::LayoutContext &ctx) const { + if (size_t s = ctx.query_alignment(this)) { + return s; + } + + size_t elem_align = element_type_->cast() + ->memory_alignment_size(ctx); + + if (ctx.is()) { + // With STD140 layout, array alignment is rounded up to `sizeof(vec4)` + constexpr size_t vec4_size = sizeof(float) * 4; + elem_align = tinyir::ceil_div(elem_align, vec4_size) * vec4_size; + } + + ctx.register_alignment(this, elem_align); + return elem_align; +} + +size_t ArrayType::nth_element_offset(int n, tinyir::LayoutContext &ctx) const { + size_t elem_align = this->memory_alignment_size(ctx); + + return elem_align * n; +} + +bool bitcast_possible(tinyir::Type *a, tinyir::Type *b, bool _inverted) { + if (a->is() && b->is()) { + return a->as()->num_bits() == b->as()->num_bits(); + } else if (a->is() && b->is()) { + return a->as()->num_bits() == b->as()->num_bits(); + } else if (!_inverted) { + return bitcast_possible(b, a, true); + } + return false; +} + +const tinyir::Type *translate_ti_primitive(tinyir::Block &ir_module, + const DataType t) { + if (t->is()) { + if (t == 
PrimitiveType::i8) { + return ir_module.emplace_back(/*num_bits=*/8, + /*is_signed=*/true); + } else if (t == PrimitiveType::i16) { + return ir_module.emplace_back(/*num_bits=*/16, + /*is_signed=*/true); + } else if (t == PrimitiveType::i32) { + return ir_module.emplace_back(/*num_bits=*/32, + /*is_signed=*/true); + } else if (t == PrimitiveType::i64) { + return ir_module.emplace_back(/*num_bits=*/64, + /*is_signed=*/true); + } else if (t == PrimitiveType::u8) { + return ir_module.emplace_back(/*num_bits=*/8, + /*is_signed=*/false); + } else if (t == PrimitiveType::u16) { + return ir_module.emplace_back(/*num_bits=*/16, + /*is_signed=*/false); + } else if (t == PrimitiveType::u32) { + return ir_module.emplace_back(/*num_bits=*/32, + /*is_signed=*/false); + } else if (t == PrimitiveType::u64) { + return ir_module.emplace_back(/*num_bits=*/64, + /*is_signed=*/false); + } else if (t == PrimitiveType::f16) { + return ir_module.emplace_back(/*num_bits=*/16); + } else if (t == PrimitiveType::f32) { + return ir_module.emplace_back(/*num_bits=*/32); + } else if (t == PrimitiveType::f64) { + return ir_module.emplace_back(/*num_bits=*/64); + } else { + TI_NOT_IMPLEMENTED; + } + } else { + TI_NOT_IMPLEMENTED; + } +} + +void TypeVisitor::visit_type(const tinyir::Type *type) { + if (type->is()) { + visit_physical_pointer_type(type->as()); + } else if (type->is()) { + visit_small_vector_type(type->as()); + } else if (type->is()) { + visit_array_type(type->as()); + } else if (type->is()) { + visit_struct_type(type->as()); + } else if (type->is()) { + visit_int_type(type->as()); + } else if (type->is()) { + visit_float_type(type->as()); + } +} + +class TypePrinter : public TypeVisitor { + private: + std::string result_; + STD140LayoutContext layout_ctx_; + + uint32_t head_{0}; + std::unordered_map idmap_; + + uint32_t get_id(const tinyir::Type *type) { + if (idmap_.find(type) == idmap_.end()) { + uint32_t id = head_++; + idmap_[type] = id; + return id; + } else { + return idmap_[type]; + } + } + + public: + void visit_int_type(const IntType *type) override { + result_ += fmt::format("T{} = {}int{}_t\n", get_id(type), + type->is_signed() ? 
"" : "u", type->num_bits()); + } + + void visit_float_type(const FloatType *type) override { + result_ += fmt::format("T{} = float{}_t\n", get_id(type), type->num_bits()); + } + + void visit_physical_pointer_type(const PhysicalPointerType *type) override { + result_ += fmt::format("T{} = T{} *\n", get_id(type), + get_id(type->get_pointed_type())); + } + + void visit_struct_type(const StructType *type) override { + result_ += fmt::format("T{} = struct {{", get_id(type)); + for (int i = 0; i < type->get_num_elements(); i++) { + result_ += fmt::format("T{}, ", get_id(type->nth_element_type(i))); + } + result_ += "}}\n"; + } + + void visit_small_vector_type(const SmallVectorType *type) override { + result_ += fmt::format("T{} = small_vector\n", get_id(type), + get_id(type->element_type()), + type->get_constant_shape()[0]); + } + + void visit_array_type(const ArrayType *type) override { + result_ += fmt::format("T{} = array\n", get_id(type), + get_id(type->element_type()), + type->get_constant_shape()[0]); + } + + static std::string print_types(const tinyir::Block *block) { + TypePrinter p; + p.visit(block); + return p.result_; + } +}; + +std::string ir_print_types(const tinyir::Block *block) { + return TypePrinter::print_types(block); +} + +class TypeReducer : public TypeVisitor { + public: + std::unique_ptr copy{nullptr}; + std::unordered_map &oldptr2newptr; + + TypeReducer( + std::unordered_map &old2new) + : oldptr2newptr(old2new) { + copy = std::make_unique(); + old2new.clear(); + } + + const tinyir::Type *check_type(const tinyir::Type *type) { + if (oldptr2newptr.find(type) != oldptr2newptr.end()) { + return oldptr2newptr[type]; + } + for (const auto &t : copy->nodes()) { + if (t->equals(type)) { + oldptr2newptr[type] = (const tinyir::Type *)t.get(); + return (const tinyir::Type *)t.get(); + } + } + return nullptr; + } + + void visit_int_type(const IntType *type) override { + if (!check_type(type)) { + oldptr2newptr[type] = copy->emplace_back(*type); + } + } + + void visit_float_type(const FloatType *type) override { + if (!check_type(type)) { + oldptr2newptr[type] = copy->emplace_back(*type); + } + } + + void visit_physical_pointer_type(const PhysicalPointerType *type) override { + if (!check_type(type)) { + const tinyir::Type *pointed = check_type(type->get_pointed_type()); + TI_ASSERT(pointed); + oldptr2newptr[type] = copy->emplace_back(pointed); + } + } + + void visit_struct_type(const StructType *type) override { + if (!check_type(type)) { + std::vector elements; + for (int i = 0; i < type->get_num_elements(); i++) { + const tinyir::Type *elm = check_type(type->nth_element_type(i)); + TI_ASSERT(elm); + elements.push_back(elm); + } + oldptr2newptr[type] = copy->emplace_back(elements); + } + } + + void visit_small_vector_type(const SmallVectorType *type) override { + if (!check_type(type)) { + const tinyir::Type *element = check_type(type->element_type()); + TI_ASSERT(element); + oldptr2newptr[type] = copy->emplace_back( + element, type->get_constant_shape()[0]); + } + } + + void visit_array_type(const ArrayType *type) override { + if (!check_type(type)) { + const tinyir::Type *element = check_type(type->element_type()); + TI_ASSERT(element); + oldptr2newptr[type] = + copy->emplace_back(element, type->get_constant_shape()[0]); + } + } +}; + +std::unique_ptr ir_reduce_types( + tinyir::Block *blk, + std::unordered_map &old2new) { + TypeReducer reducer(old2new); + reducer.visit(blk); + return std::move(reducer.copy); +} + +class Translate2Spirv : public TypeVisitor { + private: + IRBuilder 
*spir_builder_{nullptr}; + tinyir::LayoutContext &layout_context_; + + public: + std::unordered_map ir_node_2_spv_value; + + Translate2Spirv(IRBuilder *spir_builder, + tinyir::LayoutContext &layout_context) + : spir_builder_(spir_builder), layout_context_(layout_context) { + } + + void visit_int_type(const IntType *type) override { + SType vt = spir_builder_->get_null_type(); + spir_builder_->declare_global(spv::OpTypeInt, vt, type->num_bits(), + type->is_signed() ? 1 : 0); + ir_node_2_spv_value[type] = vt.id; + } + + void visit_float_type(const FloatType *type) override { + SType vt = spir_builder_->get_null_type(); + spir_builder_->declare_global(spv::OpTypeFloat, vt, type->num_bits()); + ir_node_2_spv_value[type] = vt.id; + } + + void visit_physical_pointer_type(const PhysicalPointerType *type) override { + SType vt = spir_builder_->get_null_type(); + spir_builder_->declare_global( + spv::OpTypePointer, vt, spv::StorageClassPhysicalStorageBuffer, + ir_node_2_spv_value[type->get_pointed_type()]); + ir_node_2_spv_value[type] = vt.id; + } + + void visit_struct_type(const StructType *type) override { + std::vector element_ids; + for (int i = 0; i < type->get_num_elements(); i++) { + element_ids.push_back(ir_node_2_spv_value[type->nth_element_type(i)]); + } + SType vt = spir_builder_->get_null_type(); + spir_builder_->declare_global(spv::OpTypeStruct, vt, element_ids); + ir_node_2_spv_value[type] = vt.id; + for (int i = 0; i < type->get_num_elements(); i++) { + spir_builder_->decorate(spv::OpMemberDecorate, vt, i, + spv::DecorationOffset, + type->nth_element_offset(i, layout_context_)); + } + } + + void visit_small_vector_type(const SmallVectorType *type) override { + SType vt = spir_builder_->get_null_type(); + spir_builder_->declare_global(spv::OpTypeVector, vt, + ir_node_2_spv_value[type->element_type()], + type->get_constant_shape()[0]); + ir_node_2_spv_value[type] = vt.id; + } + + void visit_array_type(const ArrayType *type) override { + SType vt = spir_builder_->get_null_type(); + spir_builder_->declare_global(spv::OpTypeArray, vt, + ir_node_2_spv_value[type->element_type()], + type->get_constant_shape()[0]); + ir_node_2_spv_value[type] = vt.id; + spir_builder_->decorate(spv::OpDecorate, vt, spv::DecorationArrayStride, + type->memory_alignment_size(layout_context_)); + } +}; + +std::unordered_map ir_translate_to_spirv( + const tinyir::Block *blk, + tinyir::LayoutContext &layout_ctx, + IRBuilder *spir_builder) { + Translate2Spirv translator(spir_builder, layout_ctx); + translator.visit(blk); + return std::move(translator.ir_node_2_spv_value); +} + +} // namespace spirv +} // namespace lang +} // namespace taichi diff --git a/taichi/codegen/spirv/spirv_types.h b/taichi/codegen/spirv/spirv_types.h new file mode 100644 index 0000000000000..8870d697f6a55 --- /dev/null +++ b/taichi/codegen/spirv/spirv_types.h @@ -0,0 +1,245 @@ +#pragma once + +#include "lib_tiny_ir.h" +#include "taichi/ir/type.h" + +namespace taichi { +namespace lang { +namespace spirv { + +class STD140LayoutContext : public tinyir::LayoutContext {}; +class STD430LayoutContext : public tinyir::LayoutContext {}; + +class IntType : public tinyir::Type, public tinyir::MemRefElementTypeInterface { + public: + IntType(int num_bits, bool is_signed) + : num_bits_(num_bits), is_signed_(is_signed) { + } + + int num_bits() const { + return num_bits_; + } + + bool is_signed() const { + return is_signed_; + } + + size_t memory_size(tinyir::LayoutContext &ctx) const override { + return tinyir::ceil_div(num_bits(), 8); + } + + size_t 
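// [Editor's note: illustrative usage sketch, not part of the patch.]
// ir_reduce_types() above deduplicates structurally identical nodes: any node
// that compares equal to an earlier one collapses to a single node in the
// copied block, and old2new maps every original node to its survivor. A
// minimal sketch of how it is meant to be driven (mirroring the call site in
// spirv_codegen.cpp); the template arguments of old2new are inferred from
// TypeReducer and may differ slightly in the real header.
#include <memory>
#include <unordered_map>
#include <vector>
#include "taichi/codegen/spirv/spirv_types.h"

void reduce_types_demo() {
  using namespace taichi;
  using namespace taichi::lang::spirv;

  tinyir::Block blk;
  // Two structurally identical i32 nodes...
  const tinyir::Type *a =
      blk.emplace_back<IntType>(/*num_bits=*/32, /*is_signed=*/true);
  const tinyir::Type *b =
      blk.emplace_back<IntType>(/*num_bits=*/32, /*is_signed=*/true);
  std::vector<const tinyir::Type *> elems{a, b};
  const tinyir::Type *st = blk.emplace_back<StructType>(elems);

  // ...collapse to one after reduction; the struct is rebuilt on top of it.
  std::unordered_map<const tinyir::Type *, const tinyir::Type *> old2new;
  std::unique_ptr<tinyir::Block> reduced = ir_reduce_types(&blk, old2new);
  TI_ASSERT(old2new[a] == old2new[b]);
  TI_ASSERT(old2new[st] != nullptr);
  TI_TRACE("reduced types:\n{}", ir_print_types(reduced.get()));
}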
memory_alignment_size(tinyir::LayoutContext &ctx) const override { + return tinyir::ceil_div(num_bits(), 8); + } + + private: + bool is_equal(const Polymorphic &other) const override { + const IntType &t = (const IntType &)other; + return t.num_bits_ == num_bits_ && t.is_signed_ == is_signed_; + } + + int num_bits_{0}; + bool is_signed_{false}; +}; + +class FloatType : public tinyir::Type, + public tinyir::MemRefElementTypeInterface { + public: + FloatType(int num_bits) : num_bits_(num_bits) { + } + + int num_bits() const { + return num_bits_; + } + + size_t memory_size(tinyir::LayoutContext &ctx) const override { + return tinyir::ceil_div(num_bits(), 8); + } + + size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override { + return tinyir::ceil_div(num_bits(), 8); + } + + private: + int num_bits_{0}; + + bool is_equal(const Polymorphic &other) const override { + const FloatType &t = (const FloatType &)other; + return t.num_bits_ == num_bits_; + } +}; + +class PhysicalPointerType : public IntType, + public tinyir::PointerTypeInterface { + public: + PhysicalPointerType(const tinyir::Type *pointed_type) + : IntType(/*num_bits=*/64, /*is_signed=*/false), + pointed_type_(pointed_type) { + } + + const tinyir::Type *get_pointed_type() const override { + return pointed_type_; + } + + private: + const tinyir::Type *pointed_type_; + + bool is_equal(const Polymorphic &other) const override { + const PhysicalPointerType &pt = (const PhysicalPointerType &)other; + return IntType::operator==((const IntType &)other) && + pointed_type_->equals(pt.pointed_type_); + } +}; + +class StructType : public tinyir::Type, + public tinyir::AggregateTypeInterface, + public tinyir::MemRefAggregateTypeInterface { + public: + StructType(std::vector &elements) + : elements_(elements) { + } + + const tinyir::Type *nth_element_type(int n) const override { + return elements_[n]; + } + + int get_num_elements() const override { + return elements_.size(); + } + + size_t memory_size(tinyir::LayoutContext &ctx) const override; + + size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override; + + size_t nth_element_offset(int n, tinyir::LayoutContext &ctx) const override; + + private: + std::vector elements_; + + bool is_equal(const Polymorphic &other) const override { + const StructType &t = (const StructType &)other; + if (t.get_num_elements() != get_num_elements()) { + return false; + } + for (int i = 0; i < get_num_elements(); i++) { + if (!elements_[i]->equals(t.elements_[i])) { + return false; + } + } + return true; + } +}; + +class SmallVectorType : public tinyir::Type, + public tinyir::ShapedTypeInterface, + public tinyir::MemRefElementTypeInterface { + public: + SmallVectorType(const tinyir::Type *element_type, int num_elements); + + const tinyir::Type *element_type() const override { + return element_type_; + } + + bool is_constant_shape() const override { + return true; + } + + std::vector get_constant_shape() const override { + return {size_t(num_elements_)}; + } + + size_t memory_size(tinyir::LayoutContext &ctx) const override; + + size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override; + + private: + bool is_equal(const Polymorphic &other) const override { + const SmallVectorType &t = (const SmallVectorType &)other; + return num_elements_ == t.num_elements_ && + element_type_->equals(t.element_type_); + } + + const tinyir::Type *element_type_{nullptr}; + int num_elements_{0}; +}; + +class ArrayType : public tinyir::Type, + public tinyir::ShapedTypeInterface, + public 
tinyir::MemRefAggregateTypeInterface { + public: + ArrayType(const tinyir::Type *element_type, size_t size) + : element_type_(element_type), size_(size) { + } + + const tinyir::Type *element_type() const override { + return element_type_; + } + + bool is_constant_shape() const override { + return true; + } + + std::vector get_constant_shape() const override { + return {size_}; + } + + size_t memory_size(tinyir::LayoutContext &ctx) const override; + + size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override; + + size_t nth_element_offset(int n, tinyir::LayoutContext &ctx) const override; + + private: + bool is_equal(const Polymorphic &other) const override { + const ArrayType &t = (const ArrayType &)other; + return size_ == t.size_ && element_type_->equals(t.element_type_); + } + + const tinyir::Type *element_type_{nullptr}; + size_t size_{0}; +}; + +bool bitcast_possible(tinyir::Type *a, tinyir::Type *b, bool _inverted = false); + +class TypeVisitor : public tinyir::Visitor { + public: + void visit_type(const tinyir::Type *type) override; + + virtual void visit_int_type(const IntType *type) { + } + + virtual void visit_float_type(const FloatType *type) { + } + + virtual void visit_physical_pointer_type(const PhysicalPointerType *type) { + } + + virtual void visit_struct_type(const StructType *type) { + } + + virtual void visit_small_vector_type(const SmallVectorType *type) { + } + + virtual void visit_array_type(const ArrayType *type) { + } +}; + +const tinyir::Type *translate_ti_primitive(tinyir::Block &ir_module, + const DataType t); + +std::string ir_print_types(const tinyir::Block *block); + +std::unique_ptr ir_reduce_types( + tinyir::Block *blk, + std::unordered_map &old2new); + +class IRBuilder; + +std::unordered_map ir_translate_to_spirv( + const tinyir::Block *blk, + tinyir::LayoutContext &layout_ctx, + IRBuilder *spir_builder); + +} // namespace spirv +} // namespace lang +} // namespace taichi diff --git a/taichi/runtime/gfx/runtime.cpp b/taichi/runtime/gfx/runtime.cpp index c62903e65179b..460b4d931e131 100644 --- a/taichi/runtime/gfx/runtime.cpp +++ b/taichi/runtime/gfx/runtime.cpp @@ -64,12 +64,15 @@ class HostDeviceContextBlitter { RuntimeContext::DevAllocType::kNone && ext_arr_size.at(i)) { // Only need to blit ext arrs (host array) - DeviceAllocation buffer = ext_arrays.at(i); - char *const device_arr_ptr = - reinterpret_cast(device_->map(buffer)); - const void *host_ptr = host_ctx_->get_arg(i); - std::memcpy(device_arr_ptr, host_ptr, ext_arr_size.at(i)); - device_->unmap(buffer); + uint32_t access = uint32_t(ctx_attribs_->arr_access.at(i)); + if (access & uint32_t(irpass::ExternalPtrAccess::READ)) { + DeviceAllocation buffer = ext_arrays.at(i); + char *const device_arr_ptr = + reinterpret_cast(device_->map(buffer)); + const void *host_ptr = host_ctx_->get_arg(i); + std::memcpy(device_arr_ptr, host_ptr, ext_arr_size.at(i)); + device_->unmap(buffer); + } } // Substitue in the device address if supported if ((host_ctx_->device_allocation_type[i] == @@ -125,7 +128,7 @@ class HostDeviceContextBlitter { bool device_to_host( CommandList *cmdlist, - const std::unordered_map &ext_arrays, + const std::unordered_map &ext_array_shadows, const std::unordered_map &ext_arr_size, const std::vector &wait_semaphore) { if (ctx_attribs_->empty()) { @@ -159,12 +162,15 @@ class HostDeviceContextBlitter { RuntimeContext::DevAllocType::kNone && ext_arr_size.at(i)) { // Only need to blit ext arrs (host array) - DeviceAllocation buffer = ext_arrays.at(i); - char *const device_arr_ptr 
= - reinterpret_cast(device_->map(buffer)); - void *host_ptr = host_ctx_->get_arg(i); - std::memcpy(host_ptr, device_arr_ptr, ext_arr_size.at(i)); - device_->unmap(buffer); + uint32_t access = uint32_t(ctx_attribs_->arr_access.at(i)); + if (access & uint32_t(irpass::ExternalPtrAccess::WRITE)) { + DeviceAllocation buffer = ext_array_shadows.at(i); + char *const device_arr_ptr = + reinterpret_cast(device_->map(buffer)); + void *host_ptr = host_ctx_->get_arg(i); + std::memcpy(host_ptr, device_arr_ptr, ext_arr_size.at(i)); + device_->unmap(buffer); + } } } } @@ -315,60 +321,8 @@ size_t CompiledTaichiKernel::get_ret_buffer_size() const { return ret_buffer_size_; } -void CompiledTaichiKernel::generate_command_list( - CommandList *cmdlist, - DeviceAllocationGuard *args_buffer, - DeviceAllocationGuard *ret_buffer, - const std::unordered_map &ext_arrs, - const std::unordered_map &textures) const { - const auto &task_attribs = ti_kernel_attribs_.tasks_attribs; - - for (int i = 0; i < task_attribs.size(); ++i) { - const auto &attribs = task_attribs[i]; - auto vp = pipelines_[i].get(); - const int group_x = (attribs.advisory_total_num_threads + - attribs.advisory_num_threads_per_group - 1) / - attribs.advisory_num_threads_per_group; - ResourceBinder *binder = vp->resource_binder(); - for (auto &bind : attribs.buffer_binds) { - if (bind.buffer.type == BufferType::ExtArr) { - binder->rw_buffer(0, bind.binding, ext_arrs.at(bind.buffer.root_id)); - } else if (args_buffer && bind.buffer.type == BufferType::Args) { - binder->buffer(0, bind.binding, *args_buffer); - } else if (ret_buffer && bind.buffer.type == BufferType::Rets) { - binder->rw_buffer(0, bind.binding, *ret_buffer); - } else { - DeviceAllocation *alloc = input_buffers_.at(bind.buffer); - if (alloc) { - binder->rw_buffer(0, bind.binding, *alloc); - } - } - } - - for (auto &bind : attribs.texture_binds) { - DeviceAllocation texture = textures.at(bind.arg_id); - cmdlist->image_transition(texture, ImageLayout::undefined, - ImageLayout::shader_read); - binder->image(0, bind.binding, texture, {}); - } - - if (attribs.task_type == OffloadedTaskType::listgen) { - for (auto &bind : attribs.buffer_binds) { - if (bind.buffer.type == BufferType::ListGen) { - // FIXME: properlly support multiple list - cmdlist->buffer_fill(input_buffers_.at(bind.buffer)->get_ptr(0), - kBufferSizeEntireSize, - /*data=*/0); - cmdlist->buffer_barrier(*input_buffers_.at(bind.buffer)); - } - } - } - - cmdlist->bind_pipeline(vp); - cmdlist->bind_resources(binder); - cmdlist->dispatch(group_x); - cmdlist->memory_barrier(); - } +Pipeline *CompiledTaichiKernel::get_pipeline(int i) { + return pipelines_[i].get(); } GfxRuntime::GfxRuntime(const Params ¶ms) @@ -441,6 +395,7 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) { // `any_arrays` contain both external arrays and NDArrays std::vector> allocated_buffers; std::unordered_map any_arrays; + std::unordered_map any_array_shadows; // `ext_array_size` only holds the size of external arrays (host arrays) // As buffer size information is only needed when it needs to be allocated // and transferred by the host @@ -476,13 +431,28 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) { } } else { ext_array_size[i] = host_ctx->array_runtime_sizes[i]; + uint32_t access = uint32_t( + ti_kernel->ti_kernel_attribs().ctx_attribs.arr_access.at(i)); + // Alloc ext arr if (ext_array_size[i]) { + bool host_write = + access & uint32_t(irpass::ExternalPtrAccess::READ); auto allocated = 
device_->allocate_memory_unique( - {ext_array_size[i], /*host_write=*/true, /*host_read=*/true, + {ext_array_size[i], host_write, false, /*export_sharing=*/false, AllocUsage::Storage}); any_arrays[i] = *allocated.get(); allocated_buffers.push_back(std::move(allocated)); + + bool host_read = + access & uint32_t(irpass::ExternalPtrAccess::WRITE); + if (host_read) { + auto allocated = device_->allocate_memory_unique( + {ext_array_size[i], false, true, + /*export_sharing=*/false, AllocUsage::Storage}); + any_array_shadows[i] = *allocated.get(); + allocated_buffers.push_back(std::move(allocated)); + } } else { any_arrays[i] = kDeviceNullAllocation; } @@ -502,8 +472,61 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) { } // Record commands - ti_kernel->generate_command_list(current_cmdlist_.get(), args_buffer.get(), - ret_buffer.get(), any_arrays, textures); + const auto &task_attribs = ti_kernel->ti_kernel_attribs().tasks_attribs; + + for (int i = 0; i < task_attribs.size(); ++i) { + const auto &attribs = task_attribs[i]; + auto vp = ti_kernel->get_pipeline(i); + const int group_x = (attribs.advisory_total_num_threads + + attribs.advisory_num_threads_per_group - 1) / + attribs.advisory_num_threads_per_group; + ResourceBinder *binder = vp->resource_binder(); + for (auto &bind : attribs.buffer_binds) { + if (bind.buffer.type == BufferType::ExtArr) { + binder->rw_buffer(0, bind.binding, any_arrays.at(bind.buffer.root_id)); + } else if (args_buffer && bind.buffer.type == BufferType::Args) { + binder->buffer(0, bind.binding, *args_buffer); + } else if (ret_buffer && bind.buffer.type == BufferType::Rets) { + binder->rw_buffer(0, bind.binding, *ret_buffer); + } else { + DeviceAllocation *alloc = ti_kernel->get_buffer_bind(bind.buffer); + if (alloc) { + binder->rw_buffer(0, bind.binding, *alloc); + } + } + } + + for (auto &bind : attribs.texture_binds) { + DeviceAllocation texture = textures.at(bind.arg_id); + current_cmdlist_->image_transition(texture, ImageLayout::undefined, + ImageLayout::shader_read); + binder->image(0, bind.binding, texture, {}); + } + + if (attribs.task_type == OffloadedTaskType::listgen) { + for (auto &bind : attribs.buffer_binds) { + if (bind.buffer.type == BufferType::ListGen) { + // FIXME: properlly support multiple list + current_cmdlist_->buffer_fill( + ti_kernel->get_buffer_bind(bind.buffer)->get_ptr(0), + kBufferSizeEntireSize, + /*data=*/0); + current_cmdlist_->buffer_barrier( + *ti_kernel->get_buffer_bind(bind.buffer)); + } + } + } + + current_cmdlist_->bind_pipeline(vp); + current_cmdlist_->bind_resources(binder); + current_cmdlist_->dispatch(group_x); + current_cmdlist_->memory_barrier(); + } + + for (auto &[id, shadow] : any_array_shadows) { + current_cmdlist_->buffer_copy( + shadow.get_ptr(0), any_arrays.at(id).get_ptr(0), ext_array_size.at(id)); + } // Keep context buffers used in this dispatch if (ti_kernel->get_args_buffer_size()) { @@ -517,7 +540,7 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) { std::vector wait_semaphore; if (ctx_blitter) { - if (ctx_blitter->device_to_host(current_cmdlist_.get(), any_arrays, + if (ctx_blitter->device_to_host(current_cmdlist_.get(), any_array_shadows, ext_array_size, wait_semaphore)) { current_cmdlist_ = nullptr; ctx_buffers_.clear(); diff --git a/taichi/runtime/gfx/runtime.h b/taichi/runtime/gfx/runtime.h index 6c7cbc8a50563..844d5084e6472 100644 --- a/taichi/runtime/gfx/runtime.h +++ b/taichi/runtime/gfx/runtime.h @@ -54,12 +54,11 @@ class CompiledTaichiKernel 
{ size_t get_args_buffer_size() const; size_t get_ret_buffer_size() const; - void generate_command_list( - CommandList *cmdlist, - DeviceAllocationGuard *args_buffer, - DeviceAllocationGuard *ret_buffer, - const std::unordered_map &ext_arrs, - const std::unordered_map &textures) const; + Pipeline *get_pipeline(int i); + + DeviceAllocation *get_buffer_bind(const BufferInfo &bind) { + return input_buffers_[bind]; + } private: TaichiKernelAttributes ti_kernel_attribs_; diff --git a/tests/cpp/aot/aot_save_load_test.cpp b/tests/cpp/aot/aot_save_load_test.cpp index 78d8d16eb0939..5a5998c0a3758 100644 --- a/tests/cpp/aot/aot_save_load_test.cpp +++ b/tests/cpp/aot/aot_save_load_test.cpp @@ -268,6 +268,7 @@ TEST(AotSaveLoad, VulkanNdarray) { const int size = 10; taichi::lang::Device::AllocParams alloc_params; alloc_params.host_write = true; + alloc_params.host_read = true; alloc_params.size = size * sizeof(int); alloc_params.usage = taichi::lang::AllocUsage::Storage; DeviceAllocation devalloc_arr_ =
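// [Editor's note: condensed sketch of the new external-array staging flow in
// runtime.cpp above, not part of the patch. stage_ext_array() is hypothetical;
// Device, AllocParams, DeviceAllocationGuard, CommandList and
// irpass::ExternalPtrAccess are the names used by the diff.]
// The per-argument access mask produced by codegen now drives the staging:
//   READ  -> the device-side array must be host_write-able, because
//            host_to_device() uploads the host data into it before the launch;
//   WRITE -> an extra host_read-able "shadow" buffer is allocated, the command
//            list copies device array -> shadow after the dispatches, and
//            device_to_host() memcpys the result back from that shadow.
#include <cstdint>
#include <memory>
#include <vector>
#include "taichi/backends/device.h"
#include "taichi/ir/transforms.h"

taichi::lang::DeviceAllocation stage_ext_array(
    taichi::lang::Device *device,
    taichi::lang::CommandList *cmdlist,
    uint32_t access,  // bitwise OR of irpass::ExternalPtrAccess values
    size_t size,
    std::vector<std::unique_ptr<taichi::lang::DeviceAllocationGuard>> &keep_alive,
    taichi::lang::DeviceAllocation *shadow_out) {
  using namespace taichi::lang;
  const bool kernel_reads = access & uint32_t(irpass::ExternalPtrAccess::READ);
  const bool kernel_writes = access & uint32_t(irpass::ExternalPtrAccess::WRITE);

  // Device copy of the host array; host-writable only if the kernel reads it.
  auto arr = device->allocate_memory_unique(
      {size, /*host_write=*/kernel_reads, /*host_read=*/false,
       /*export_sharing=*/false, AllocUsage::Storage});
  DeviceAllocation device_arr = *arr.get();
  keep_alive.push_back(std::move(arr));

  if (kernel_writes) {
    // Host-readable shadow, filled by a GPU-side copy. In launch_kernel() this
    // copy is recorded after the kernel dispatches; it is shown here only to
    // pair the source and destination pointers.
    auto shadow = device->allocate_memory_unique(
        {size, /*host_write=*/false, /*host_read=*/true,
         /*export_sharing=*/false, AllocUsage::Storage});
    *shadow_out = *shadow.get();
    keep_alive.push_back(std::move(shadow));
    cmdlist->buffer_copy(shadow_out->get_ptr(0), device_arr.get_ptr(0), size);
  }
  return device_arr;
}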