diff --git a/taichi/backends/vulkan/aot_module_builder_impl.cpp b/taichi/backends/vulkan/aot_module_builder_impl.cpp index c57d785d7d82b..60cc5f7aa266f 100644 --- a/taichi/backends/vulkan/aot_module_builder_impl.cpp +++ b/taichi/backends/vulkan/aot_module_builder_impl.cpp @@ -46,12 +46,12 @@ class AotDataConverter { for (const auto &arg : in.ctx_attribs.args()) { if (!arg.is_array) { aot::ScalarArg scalar_arg{}; - scalar_arg.dtype_name = arg.dt.to_string(); + scalar_arg.dtype_name = PrimitiveType::get(arg.dtype).to_string(); scalar_arg.offset_in_args_buf = arg.offset_in_mem; res.scalar_args[arg.index] = scalar_arg; } else { aot::ArrayArg arr_arg{}; - arr_arg.dtype_name = arg.dt.to_string(); + arr_arg.dtype_name = PrimitiveType::get(arg.dtype).to_string(); arr_arg.field_dim = arg.field_dim; arr_arg.element_shape = arg.element_shape; arr_arg.shape_offset_in_args_buf = arg.index * sizeof(int32_t); @@ -105,32 +105,6 @@ AotModuleBuilderImpl::AotModuleBuilderImpl( } } -uint32_t AotModuleBuilderImpl::to_vk_dtype_enum(DataType dt) { - if (dt == PrimitiveType::u64) { - return 0; - } else if (dt == PrimitiveType::i64) { - return 1; - } else if (dt == PrimitiveType::u32) { - return 2; - } else if (dt == PrimitiveType::i32) { - return 3; - } else if (dt == PrimitiveType::u16) { - return 4; - } else if (dt == PrimitiveType::i16) { - return 5; - } else if (dt == PrimitiveType::u8) { - return 6; - } else if (dt == PrimitiveType::i8) { - return 7; - } else if (dt == PrimitiveType::f64) { - return 8; - } else if (dt == PrimitiveType::f32) { - return 9; - } else { - TI_NOT_IMPLEMENTED - } -} - std::string AotModuleBuilderImpl::write_spv_file( const std::string &output_dir, const TaskAttributes &k, @@ -194,7 +168,7 @@ void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier, aot::CompiledFieldData field_data; field_data.field_name = identifier; field_data.is_scalar = is_scalar; - field_data.dtype = to_vk_dtype_enum(dt); + field_data.dtype = static_cast(dt->cast()->type); field_data.dtype_name = dt.to_string(); field_data.shape = shape; field_data.mem_offset_in_parent = dense_desc.mem_offset_in_parent_cell; diff --git a/taichi/backends/vulkan/aot_module_builder_impl.h b/taichi/backends/vulkan/aot_module_builder_impl.h index 408e24b1f3dc8..bbd6b40e4df48 100644 --- a/taichi/backends/vulkan/aot_module_builder_impl.h +++ b/taichi/backends/vulkan/aot_module_builder_impl.h @@ -40,8 +40,6 @@ class AotModuleBuilderImpl : public AotModuleBuilder { const TaskAttributes &k, const std::vector &source_code) const; - uint32_t to_vk_dtype_enum(DataType dt); - const std::vector &compiled_structs_; TaichiAotData ti_aot_data_; std::unique_ptr aot_target_device_; diff --git a/taichi/backends/vulkan/runtime.cpp b/taichi/backends/vulkan/runtime.cpp index 410c40f552151..e500e8921bc87 100644 --- a/taichi/backends/vulkan/runtime.cpp +++ b/taichi/backends/vulkan/runtime.cpp @@ -67,16 +67,15 @@ class HostDeviceContextBlitter { char *const device_base = reinterpret_cast(device_->map(*device_args_buffer_)); -#define TO_DEVICE(short_type, type) \ - if (dt->is_primitive(PrimitiveTypeID::short_type)) { \ - auto d = host_ctx_->get_arg(i); \ - reinterpret_cast(device_ptr)[0] = d; \ - break; \ +#define TO_DEVICE(short_type, type) \ + if (arg.dtype == PrimitiveTypeID::short_type) { \ + auto d = host_ctx_->get_arg(i); \ + reinterpret_cast(device_ptr)[0] = d; \ + break; \ } for (int i = 0; i < ctx_attribs_->args().size(); ++i) { const auto &arg = ctx_attribs_->args()[i]; - const auto dt = arg.dt; char *device_ptr = device_base + arg.offset_in_mem; do { if (arg.is_array) { @@ -118,13 +117,14 @@ class HostDeviceContextBlitter { TO_DEVICE(f64, float64) } if (device_->get_cap(DeviceCapability::spirv_has_float16)) { - if (dt->is_primitive(PrimitiveTypeID::f16)) { + if (arg.dtype == PrimitiveTypeID::f16) { auto d = fp16_ieee_from_fp32_value(host_ctx_->get_arg(i)); reinterpret_cast(device_ptr)[0] = d; break; } } - TI_ERROR("Vulkan does not support arg type={}", data_type_name(arg.dt)); + TI_ERROR("Vulkan does not support arg type={}", + PrimitiveType::get(arg.dtype).to_string()); } while (0); } @@ -196,8 +196,8 @@ class HostDeviceContextBlitter { // *arg* on the host context. const auto &ret = ctx_attribs_->rets()[i]; char *device_ptr = device_base + ret.offset_in_mem; - const auto dt = ret.dt; - const auto num = ret.stride / data_type_size(ret.dt); + const auto dt = PrimitiveType::get(ret.dtype); + const auto num = ret.stride / data_type_size(dt); for (int j = 0; j < num; ++j) { if (device_->get_cap(DeviceCapability::spirv_has_int8)) { TO_HOST(i8, int8, j) @@ -227,7 +227,7 @@ class HostDeviceContextBlitter { } } TI_ERROR("Vulkan does not support return value type={}", - data_type_name(ret.dt)); + data_type_name(PrimitiveType::get(ret.dtype))); } } #undef TO_HOST diff --git a/taichi/codegen/spirv/kernel_utils.cpp b/taichi/codegen/spirv/kernel_utils.cpp index b29e03b460372..759061a6823cf 100644 --- a/taichi/codegen/spirv/kernel_utils.cpp +++ b/taichi/codegen/spirv/kernel_utils.cpp @@ -53,10 +53,13 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel) rets_bytes_(0), extra_args_bytes_(RuntimeContext::extra_args_size) { arg_attribs_vec_.reserve(kernel.args.size()); + // TODO: We should be able to limit Kernel args and rets to be primitive types + // as well but let's leave that as a followup up PR. for (const auto &ka : kernel.args) { ArgAttributes aa; - aa.dt = ka.dt; - const size_t dt_bytes = data_type_size(aa.dt); + TI_ASSERT(ka.dt->is()); + aa.dtype = ka.dt->cast()->type; + const size_t dt_bytes = data_type_size(ka.dt); aa.is_array = ka.is_array; if (aa.is_array) { aa.field_dim = ka.total_dim - ka.element_shape.size(); @@ -70,13 +73,16 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel) RetAttributes ra; size_t dt_bytes{0}; if (auto tensor_type = kr.dt->cast()) { - ra.dt = tensor_type->get_element_type(); - dt_bytes = data_type_size(ra.dt); + auto tensor_dtype = tensor_type->get_element_type(); + TI_ASSERT(tensor_dtype->is()); + ra.dtype = tensor_dtype->cast()->type; + dt_bytes = data_type_size(tensor_dtype); ra.is_array = true; ra.stride = tensor_type->get_num_elements() * dt_bytes; } else { - ra.dt = kr.dt; - dt_bytes = data_type_size(ra.dt); + TI_ASSERT(kr.dt->is()); + ra.dtype = kr.dt->cast()->type; + dt_bytes = data_type_size(kr.dt); ra.is_array = false; ra.stride = dt_bytes; } @@ -88,9 +94,10 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel) size_t bytes = offset; for (int i = 0; i < vec->size(); ++i) { auto &attribs = (*vec)[i]; - const size_t dt_bytes = (attribs.is_array && !is_ret) - ? sizeof(uint64_t) - : data_type_size(attribs.dt); + const size_t dt_bytes = + (attribs.is_array && !is_ret) + ? sizeof(uint64_t) + : data_type_size(PrimitiveType::get(attribs.dtype)); // Align bytes to the nearest multiple of dt_bytes bytes = (bytes + dt_bytes - 1) / dt_bytes * dt_bytes; attribs.offset_in_mem = bytes; diff --git a/taichi/codegen/spirv/kernel_utils.h b/taichi/codegen/spirv/kernel_utils.h index 21d528dba12c3..95a8aa71196e3 100644 --- a/taichi/codegen/spirv/kernel_utils.h +++ b/taichi/codegen/spirv/kernel_utils.h @@ -140,12 +140,18 @@ class KernelContextAttributes { size_t offset_in_mem{0}; // Index of the input arg or the return value in the host `Context` int index{-1}; - DataType dt; + PrimitiveTypeID dtype{PrimitiveTypeID::unknown}; bool is_array{false}; std::vector element_shape; std::size_t field_dim{0}; - TI_IO_DEF(stride, offset_in_mem, index, is_array, element_shape, field_dim); + TI_IO_DEF(stride, + offset_in_mem, + index, + dtype, + is_array, + element_shape, + field_dim); }; public: diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 821ccb12b91f6..4c5adb7faed85 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -502,7 +502,7 @@ class TaskCodegen : public IRVisitor { // ir_->int_immediate_number(ir_->i32_type(), offset_in_mem); // ir_->register_value(stmt->raw_name(), val); } else { - const auto dt = arg_attribs.dt; + const auto dt = PrimitiveType::get(arg_attribs.dtype); const auto val_type = ir_->get_primitive_type(dt); spirv::Value buffer_val = ir_->make_value( spv::OpAccessChain, @@ -1806,9 +1806,9 @@ class TaskCodegen : public IRVisitor { "arg_ptr" + std::to_string(arg.index), arg.offset_in_mem); } else { - struct_components_.emplace_back(ir_->get_primitive_type(arg.dt), - "arg" + std::to_string(arg.index), - arg.offset_in_mem); + struct_components_.emplace_back( + ir_->get_primitive_type(PrimitiveType::get(arg.dtype)), + "arg" + std::to_string(arg.index), arg.offset_in_mem); } } // A compromise for use in constants buffer @@ -1833,7 +1833,8 @@ class TaskCodegen : public IRVisitor { // Now we only have one ret TI_ASSERT(ctx_attribs_->rets().size() == 1); for (auto &ret : ctx_attribs_->rets()) { - if (auto tensor_type = ret.dt->cast()) { + if (auto tensor_type = + PrimitiveType::get(ret.dtype)->cast()) { struct_components_.emplace_back( ir_->get_array_type( ir_->get_primitive_type(tensor_type->get_element_type()), @@ -1841,7 +1842,8 @@ class TaskCodegen : public IRVisitor { "ret" + std::to_string(ret.index), ret.offset_in_mem); } else { struct_components_.emplace_back( - ir_->get_array_type(ir_->get_primitive_type(ret.dt), 1), + ir_->get_array_type( + ir_->get_primitive_type(PrimitiveType::get(ret.dtype)), 1), "ret" + std::to_string(ret.index), ret.offset_in_mem); } } diff --git a/taichi/program/context.h b/taichi/program/context.h index 0d6d7eef413a6..4d7562054f53a 100644 --- a/taichi/program/context.h +++ b/taichi/program/context.h @@ -50,6 +50,11 @@ struct RuntimeContext { void set_device_allocation(int i, bool is_device_allocation_) { is_device_allocation[i] = is_device_allocation_; } + + template + T get_ret(int i) { + return taichi_union_cast_with_different_sizes(result_buffer[i]); + } #endif }; diff --git a/tests/cpp/aot/aot_save_load_test.cpp b/tests/cpp/aot/aot_save_load_test.cpp index 6b66213121e00..798af91e54926 100644 --- a/tests/cpp/aot/aot_save_load_test.cpp +++ b/tests/cpp/aot/aot_save_load_test.cpp @@ -1,6 +1,7 @@ #include "gtest/gtest.h" #include "taichi/ir/ir_builder.h" #include "taichi/ir/statements.h" +#include "taichi/inc/constants.h" #include "taichi/program/program.h" #ifdef TI_WITH_VULKAN #include "taichi/backends/vulkan/aot_module_loader_impl.h" @@ -29,7 +30,24 @@ using namespace lang; auto aot_builder = program.make_aot_module_builder(Arch::vulkan); - std::unique_ptr kernel_init, kernel_ret; + std::unique_ptr kernel_init, kernel_ret, kernel_simple_ret; + + { + /* + @ti.kernel + def ret() -> ti.f32: + sum = 0.2 + return sum + */ + IRBuilder builder; + auto *sum = builder.create_local_var(PrimitiveType::f32); + builder.create_local_store(sum, builder.get_float32(0.2)); + builder.create_return(builder.create_local_load(sum)); + + kernel_simple_ret = + std::make_unique(program, builder.extract_ir(), "simple_ret"); + kernel_simple_ret->insert_ret(PrimitiveType::f32); + } { /* @@ -79,6 +97,7 @@ using namespace lang; kernel_ret->insert_ret(PrimitiveType::i32); } + aot_builder->add("simple_ret", kernel_simple_ret.get()); aot_builder->add_field("place", place, true, place->dt, {n}, 1, 1); aot_builder->add("init", kernel_init.get()); aot_builder->add("ret", kernel_ret.get()); @@ -103,6 +122,7 @@ TEST(AotSaveLoad, Vulkan) { std::make_unique(Arch::vulkan, nullptr); result_buffer = (taichi::uint64 *)memory_pool->allocate( sizeof(taichi::uint64) * taichi_result_buffer_entries, 8); + host_ctx.result_buffer = result_buffer; // Create Taichi Device for computation lang::vulkan::VulkanDeviceCreator::Params evd_params; @@ -132,6 +152,13 @@ TEST(AotSaveLoad, Vulkan) { EXPECT_EQ(root_size, 64); vulkan_runtime->add_root_buffer(root_size); + auto simple_ret_kernel = vk_module->get_kernel("simple_ret"); + EXPECT_TRUE(simple_ret_kernel); + + simple_ret_kernel->launch(&host_ctx); + vulkan_runtime->synchronize(); + EXPECT_FLOAT_EQ(host_ctx.get_ret(0), 0.2); + auto init_kernel = vk_module->get_kernel("init"); EXPECT_TRUE(init_kernel);