[aot] Support return in vulkan aot #4593

Merged · 1 commit · Mar 21, 2022
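This PR replaces the raw DataType handle stored in the SPIR-V kernel argument/return attributes with a serializable PrimitiveTypeID, deletes the hand-written to_vk_dtype_enum table, threads the new field through the Vulkan AOT module builder, the runtime's host/device blitter, and the SPIR-V codegen, adds RuntimeContext::get_ret for reading return values on the host, and extends the AOT save/load test with a kernel that returns an f32.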
32 changes: 3 additions & 29 deletions taichi/backends/vulkan/aot_module_builder_impl.cpp
@@ -46,12 +46,12 @@ class AotDataConverter {
     for (const auto &arg : in.ctx_attribs.args()) {
       if (!arg.is_array) {
         aot::ScalarArg scalar_arg{};
-        scalar_arg.dtype_name = arg.dt.to_string();
+        scalar_arg.dtype_name = PrimitiveType::get(arg.dtype).to_string();
         scalar_arg.offset_in_args_buf = arg.offset_in_mem;
         res.scalar_args[arg.index] = scalar_arg;
       } else {
         aot::ArrayArg arr_arg{};
-        arr_arg.dtype_name = arg.dt.to_string();
+        arr_arg.dtype_name = PrimitiveType::get(arg.dtype).to_string();
         arr_arg.field_dim = arg.field_dim;
         arr_arg.element_shape = arg.element_shape;
         arr_arg.shape_offset_in_args_buf = arg.index * sizeof(int32_t);
@@ -105,32 +105,6 @@ AotModuleBuilderImpl::AotModuleBuilderImpl(
   }
 }
 
-uint32_t AotModuleBuilderImpl::to_vk_dtype_enum(DataType dt) {
-  if (dt == PrimitiveType::u64) {
-    return 0;
-  } else if (dt == PrimitiveType::i64) {
-    return 1;
-  } else if (dt == PrimitiveType::u32) {
-    return 2;
-  } else if (dt == PrimitiveType::i32) {
-    return 3;
-  } else if (dt == PrimitiveType::u16) {
-    return 4;
-  } else if (dt == PrimitiveType::i16) {
-    return 5;
-  } else if (dt == PrimitiveType::u8) {
-    return 6;
-  } else if (dt == PrimitiveType::i8) {
-    return 7;
-  } else if (dt == PrimitiveType::f64) {
-    return 8;
-  } else if (dt == PrimitiveType::f32) {
-    return 9;
-  } else {
-    TI_NOT_IMPLEMENTED
-  }
-}
-
 std::string AotModuleBuilderImpl::write_spv_file(
     const std::string &output_dir,
     const TaskAttributes &k,
@@ -194,7 +168,7 @@ void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
   aot::CompiledFieldData field_data;
   field_data.field_name = identifier;
   field_data.is_scalar = is_scalar;
-  field_data.dtype = to_vk_dtype_enum(dt);
+  field_data.dtype = static_cast<int>(dt->cast<PrimitiveType>()->type);
   field_data.dtype_name = dt.to_string();
   field_data.shape = shape;
   field_data.mem_offset_in_parent = dense_desc.mem_offset_in_parent_cell;
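The deleted to_vk_dtype_enum maintained a Vulkan-only integer table that had to be extended by hand for every new dtype; the replacement serializes the PrimitiveTypeID enum value itself. A minimal sketch of the round trip implied by the new code, assuming PrimitiveTypeID is a plain enum as the static_cast suggests (serialize_dtype and deserialize_dtype are illustrative names, not functions from this PR):

int serialize_dtype(const DataType &dt) {
  // The enum value itself is now the on-disk format.
  return static_cast<int>(dt->cast<PrimitiveType>()->type);
}

DataType deserialize_dtype(int v) {
  // PrimitiveType::get() rebuilds the DataType handle from the ID.
  return PrimitiveType::get(static_cast<PrimitiveTypeID>(v));
}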
2 changes: 0 additions & 2 deletions taichi/backends/vulkan/aot_module_builder_impl.h
@@ -40,8 +40,6 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
       const TaskAttributes &k,
       const std::vector<uint32_t> &source_code) const;
 
-  uint32_t to_vk_dtype_enum(DataType dt);
-
   const std::vector<CompiledSNodeStructs> &compiled_structs_;
   TaichiAotData ti_aot_data_;
   std::unique_ptr<Device> aot_target_device_;
22 changes: 11 additions & 11 deletions taichi/backends/vulkan/runtime.cpp
@@ -67,16 +67,15 @@ class HostDeviceContextBlitter {
     char *const device_base =
         reinterpret_cast<char *>(device_->map(*device_args_buffer_));
 
-#define TO_DEVICE(short_type, type)                    \
-  if (dt->is_primitive(PrimitiveTypeID::short_type)) { \
-    auto d = host_ctx_->get_arg<type>(i);              \
-    reinterpret_cast<type *>(device_ptr)[0] = d;       \
-    break;                                             \
+#define TO_DEVICE(short_type, type)               \
+  if (arg.dtype == PrimitiveTypeID::short_type) { \
+    auto d = host_ctx_->get_arg<type>(i);         \
+    reinterpret_cast<type *>(device_ptr)[0] = d;  \
+    break;                                        \
   }
 
     for (int i = 0; i < ctx_attribs_->args().size(); ++i) {
       const auto &arg = ctx_attribs_->args()[i];
-      const auto dt = arg.dt;
       char *device_ptr = device_base + arg.offset_in_mem;
       do {
         if (arg.is_array) {
@@ -118,13 +117,14 @@ class HostDeviceContextBlitter {
           TO_DEVICE(f64, float64)
         }
         if (device_->get_cap(DeviceCapability::spirv_has_float16)) {
-          if (dt->is_primitive(PrimitiveTypeID::f16)) {
+          if (arg.dtype == PrimitiveTypeID::f16) {
             auto d = fp16_ieee_from_fp32_value(host_ctx_->get_arg<float>(i));
             reinterpret_cast<uint16 *>(device_ptr)[0] = d;
             break;
           }
         }
-        TI_ERROR("Vulkan does not support arg type={}", data_type_name(arg.dt));
+        TI_ERROR("Vulkan does not support arg type={}",
+                 PrimitiveType::get(arg.dtype).to_string());
       } while (0);
     }
 
@@ -196,8 +196,8 @@ class HostDeviceContextBlitter {
       // *arg* on the host context.
       const auto &ret = ctx_attribs_->rets()[i];
       char *device_ptr = device_base + ret.offset_in_mem;
-      const auto dt = ret.dt;
-      const auto num = ret.stride / data_type_size(ret.dt);
+      const auto dt = PrimitiveType::get(ret.dtype);
+      const auto num = ret.stride / data_type_size(dt);
       for (int j = 0; j < num; ++j) {
         if (device_->get_cap(DeviceCapability::spirv_has_int8)) {
           TO_HOST(i8, int8, j)
@@ -227,7 +227,7 @@ class HostDeviceContextBlitter {
         }
       }
       TI_ERROR("Vulkan does not support return value type={}",
-               data_type_name(ret.dt));
+               data_type_name(PrimitiveType::get(ret.dtype)));
     }
   }
 #undef TO_HOST
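For reference, TO_DEVICE expands into one if-branch per primitive type: the scalar argument is read from the host RuntimeContext and stored at its precomputed byte offset in the mapped device args buffer. Written out by hand, and assuming the f32 instantiation is TO_DEVICE(f32, float32) by analogy with the f64 line above, the expansion for argument i is:

if (arg.dtype == PrimitiveTypeID::f32) {
  auto d = host_ctx_->get_arg<float32>(i);         // read from the host context
  reinterpret_cast<float32 *>(device_ptr)[0] = d;  // write into the mapped buffer
  break;                                           // this argument is done
}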
25 changes: 16 additions & 9 deletions taichi/codegen/spirv/kernel_utils.cpp
@@ -53,10 +53,13 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel)
       rets_bytes_(0),
       extra_args_bytes_(RuntimeContext::extra_args_size) {
   arg_attribs_vec_.reserve(kernel.args.size());
+  // TODO: We should be able to limit Kernel args and rets to be primitive
+  // types as well, but let's leave that as a follow-up PR.
   for (const auto &ka : kernel.args) {
     ArgAttributes aa;
-    aa.dt = ka.dt;
-    const size_t dt_bytes = data_type_size(aa.dt);
+    TI_ASSERT(ka.dt->is<PrimitiveType>());
+    aa.dtype = ka.dt->cast<PrimitiveType>()->type;
+    const size_t dt_bytes = data_type_size(ka.dt);
     aa.is_array = ka.is_array;
     if (aa.is_array) {
       aa.field_dim = ka.total_dim - ka.element_shape.size();
@@ -70,13 +73,16 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel)
     RetAttributes ra;
     size_t dt_bytes{0};
     if (auto tensor_type = kr.dt->cast<TensorType>()) {
-      ra.dt = tensor_type->get_element_type();
-      dt_bytes = data_type_size(ra.dt);
+      auto tensor_dtype = tensor_type->get_element_type();
+      TI_ASSERT(tensor_dtype->is<PrimitiveType>());
+      ra.dtype = tensor_dtype->cast<PrimitiveType>()->type;
+      dt_bytes = data_type_size(tensor_dtype);
       ra.is_array = true;
       ra.stride = tensor_type->get_num_elements() * dt_bytes;
     } else {
-      ra.dt = kr.dt;
-      dt_bytes = data_type_size(ra.dt);
+      TI_ASSERT(kr.dt->is<PrimitiveType>());
+      ra.dtype = kr.dt->cast<PrimitiveType>()->type;
+      dt_bytes = data_type_size(kr.dt);
       ra.is_array = false;
       ra.stride = dt_bytes;
     }
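Both branches reduce the return type to a PrimitiveTypeID plus a stride: a TensorType contributes its element type and num_elements * element size, a scalar contributes itself, and the TI_ASSERTs encode the invariant, flagged in the TODO above, that only primitive element types reach this point. A condensed sketch of that shared logic (reduce_ret_type is an illustrative helper, not code from this PR):

std::pair<PrimitiveTypeID, size_t> reduce_ret_type(DataType dt) {
  if (auto tensor_type = dt->cast<TensorType>()) {
    auto elem = tensor_type->get_element_type();
    TI_ASSERT(elem->is<PrimitiveType>());
    return {elem->cast<PrimitiveType>()->type,  // element dtype
            tensor_type->get_num_elements() * data_type_size(elem)};  // stride
  }
  TI_ASSERT(dt->is<PrimitiveType>());
  return {dt->cast<PrimitiveType>()->type, data_type_size(dt)};  // scalar case
}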
@@ -88,9 +94,10 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel)
     size_t bytes = offset;
     for (int i = 0; i < vec->size(); ++i) {
       auto &attribs = (*vec)[i];
-      const size_t dt_bytes = (attribs.is_array && !is_ret)
-                                  ? sizeof(uint64_t)
-                                  : data_type_size(attribs.dt);
+      const size_t dt_bytes =
+          (attribs.is_array && !is_ret)
+              ? sizeof(uint64_t)
+              : data_type_size(PrimitiveType::get(attribs.dtype));
       // Align bytes to the nearest multiple of dt_bytes
       bytes = (bytes + dt_bytes - 1) / dt_bytes * dt_bytes;
       attribs.offset_in_mem = bytes;
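The alignment line rounds the running offset up to the next multiple of the member's own size before recording it, so each scalar is naturally aligned in the args/rets buffer. A worked example of the formula (align_up is an illustrative wrapper):

size_t align_up(size_t bytes, size_t dt_bytes) {
  // Round bytes up to the nearest multiple of dt_bytes.
  return (bytes + dt_bytes - 1) / dt_bytes * dt_bytes;
}
// Placing an i32 (dt_bytes = 4) after a single i8 (bytes = 1):
//   (1 + 4 - 1) / 4 * 4 == 4, so the i32 lands at offset 4, not 1.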
10 changes: 8 additions & 2 deletions taichi/codegen/spirv/kernel_utils.h
@@ -140,12 +140,18 @@ class KernelContextAttributes {
     size_t offset_in_mem{0};
     // Index of the input arg or the return value in the host `Context`
     int index{-1};
-    DataType dt;
+    PrimitiveTypeID dtype{PrimitiveTypeID::unknown};
     bool is_array{false};
     std::vector<int> element_shape;
     std::size_t field_dim{0};
 
-    TI_IO_DEF(stride, offset_in_mem, index, is_array, element_shape, field_dim);
+    TI_IO_DEF(stride,
+              offset_in_mem,
+              index,
+              dtype,
+              is_array,
+              element_shape,
+              field_dim);
   };
 
  public:
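Swapping DataType for PrimitiveTypeID is what allows the new field to join TI_IO_DEF: the ID is a plain enum that serializes by value, whereas DataType wraps a handle into the live type system. When a real DataType is needed again it is rematerialized, as the rest of this diff does:

PrimitiveTypeID dtype{PrimitiveTypeID::unknown};  // stored and serialized by value
DataType dt = PrimitiveType::get(dtype);          // handle rebuilt on demand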
14 changes: 8 additions & 6 deletions taichi/codegen/spirv/spirv_codegen.cpp
@@ -502,7 +502,7 @@ class TaskCodegen : public IRVisitor {
       // ir_->int_immediate_number(ir_->i32_type(), offset_in_mem);
       // ir_->register_value(stmt->raw_name(), val);
     } else {
-      const auto dt = arg_attribs.dt;
+      const auto dt = PrimitiveType::get(arg_attribs.dtype);
       const auto val_type = ir_->get_primitive_type(dt);
       spirv::Value buffer_val = ir_->make_value(
           spv::OpAccessChain,
@@ -1806,9 +1806,9 @@ class TaskCodegen : public IRVisitor {
                                         "arg_ptr" + std::to_string(arg.index),
                                         arg.offset_in_mem);
       } else {
-        struct_components_.emplace_back(ir_->get_primitive_type(arg.dt),
-                                        "arg" + std::to_string(arg.index),
-                                        arg.offset_in_mem);
+        struct_components_.emplace_back(
+            ir_->get_primitive_type(PrimitiveType::get(arg.dtype)),
+            "arg" + std::to_string(arg.index), arg.offset_in_mem);
       }
     }
     // A compromise for use in constants buffer
@@ -1833,15 +1833,17 @@ class TaskCodegen : public IRVisitor {
     // Now we only have one ret
     TI_ASSERT(ctx_attribs_->rets().size() == 1);
     for (auto &ret : ctx_attribs_->rets()) {
-      if (auto tensor_type = ret.dt->cast<TensorType>()) {
+      if (auto tensor_type =
+              PrimitiveType::get(ret.dtype)->cast<TensorType>()) {
         struct_components_.emplace_back(
             ir_->get_array_type(
                 ir_->get_primitive_type(tensor_type->get_element_type()),
                 tensor_type->get_num_elements()),
             "ret" + std::to_string(ret.index), ret.offset_in_mem);
       } else {
         struct_components_.emplace_back(
-            ir_->get_array_type(ir_->get_primitive_type(ret.dt), 1),
+            ir_->get_array_type(
+                ir_->get_primitive_type(PrimitiveType::get(ret.dtype)), 1),
             "ret" + std::to_string(ret.index), ret.offset_in_mem);
       }
     }
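Both branches declare the return slot as an array member, so scalar and TensorType returns share one buffer layout; a scalar is simply an array of length one. Conceptually, the rets buffer member the codegen emits corresponds to the following C view (a sketch for f32 returns, not actual SPIR-V):

struct RetsScalarF32 {  // scalar f32 return
  float ret0[1];        // ir_->get_array_type(<f32 primitive type>, 1)
};

struct RetsTensorF32x4 {  // TensorType return with 4 f32 elements
  float ret0[4];          // ir_->get_array_type(<element type>, num_elements)
};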
5 changes: 5 additions & 0 deletions taichi/program/context.h
@@ -50,6 +50,11 @@ struct RuntimeContext {
   void set_device_allocation(int i, bool is_device_allocation_) {
     is_device_allocation[i] = is_device_allocation_;
   }
+
+  template <typename T>
+  T get_ret(int i) {
+    return taichi_union_cast_with_different_sizes<T>(result_buffer[i]);
+  }
 #endif
 };
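get_ret mirrors get_arg on the output side: the runtime writes results into result_buffer, and the host reinterprets slot i as T via taichi_union_cast_with_different_sizes. Usage as exercised by the new test below, assuming result_buffer has been attached to the context first:

host_ctx.result_buffer = result_buffer;  // host-visible storage for returns
simple_ret_kernel->launch(&host_ctx);
vulkan_runtime->synchronize();           // wait for device writes to land
float ret = host_ctx.get_ret<float>(0);  // reinterpret slot 0 as float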

29 changes: 28 additions & 1 deletion tests/cpp/aot/aot_save_load_test.cpp
@@ -1,6 +1,7 @@
 #include "gtest/gtest.h"
 #include "taichi/ir/ir_builder.h"
 #include "taichi/ir/statements.h"
+#include "taichi/inc/constants.h"
 #include "taichi/program/program.h"
 #ifdef TI_WITH_VULKAN
 #include "taichi/backends/vulkan/aot_module_loader_impl.h"
@@ -29,7 +30,24 @@ using namespace lang;
 
   auto aot_builder = program.make_aot_module_builder(Arch::vulkan);
 
-  std::unique_ptr<Kernel> kernel_init, kernel_ret;
+  std::unique_ptr<Kernel> kernel_init, kernel_ret, kernel_simple_ret;
+
+  {
+    /*
+    @ti.kernel
+    def ret() -> ti.f32:
+      sum = 0.2
+      return sum
+    */
+    IRBuilder builder;
+    auto *sum = builder.create_local_var(PrimitiveType::f32);
+    builder.create_local_store(sum, builder.get_float32(0.2));
+    builder.create_return(builder.create_local_load(sum));
+
+    kernel_simple_ret =
+        std::make_unique<Kernel>(program, builder.extract_ir(), "simple_ret");
+    kernel_simple_ret->insert_ret(PrimitiveType::f32);
+  }
 
   {
     /*
Expand Down Expand Up @@ -79,6 +97,7 @@ using namespace lang;
kernel_ret->insert_ret(PrimitiveType::i32);
}

aot_builder->add("simple_ret", kernel_simple_ret.get());
aot_builder->add_field("place", place, true, place->dt, {n}, 1, 1);
aot_builder->add("init", kernel_init.get());
aot_builder->add("ret", kernel_ret.get());
@@ -103,6 +122,7 @@ TEST(AotSaveLoad, Vulkan) {
       std::make_unique<taichi::lang::MemoryPool>(Arch::vulkan, nullptr);
   result_buffer = (taichi::uint64 *)memory_pool->allocate(
       sizeof(taichi::uint64) * taichi_result_buffer_entries, 8);
+  host_ctx.result_buffer = result_buffer;
 
   // Create Taichi Device for computation
   lang::vulkan::VulkanDeviceCreator::Params evd_params;
@@ -132,6 +152,13 @@ TEST(AotSaveLoad, Vulkan) {
   EXPECT_EQ(root_size, 64);
   vulkan_runtime->add_root_buffer(root_size);
 
+  auto simple_ret_kernel = vk_module->get_kernel("simple_ret");
+  EXPECT_TRUE(simple_ret_kernel);
+
+  simple_ret_kernel->launch(&host_ctx);
+  vulkan_runtime->synchronize();
+  EXPECT_FLOAT_EQ(host_ctx.get_ret<float>(0), 0.2);
+
   auto init_kernel = vk_module->get_kernel("init");
   EXPECT_TRUE(init_kernel);
 