Skip to content

Commit

Permalink
[aot] Support return in vulkan aot
Browse files Browse the repository at this point in the history
Vulkan AOT return values didn't work because the code was trying to serialize
`DataType`, which is a raw pointer. This PR fixes the issue by changing
it to a `PrimitiveTypeID`, which is an `int`. Note that we could also
limit general `Kernel` args and rets to primitive types only, but that is
left to a follow-up PR to avoid making this one too complicated.
  • Loading branch information
Ailing Zhang committed Mar 21, 2022
1 parent b6e7e76 commit aff5aca
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 60 deletions.
32 changes: 3 additions & 29 deletions taichi/backends/vulkan/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ class AotDataConverter {
for (const auto &arg : in.ctx_attribs.args()) {
if (!arg.is_array) {
aot::ScalarArg scalar_arg{};
scalar_arg.dtype_name = arg.dt.to_string();
scalar_arg.dtype_name = PrimitiveType::get(arg.dtype).to_string();
scalar_arg.offset_in_args_buf = arg.offset_in_mem;
res.scalar_args[arg.index] = scalar_arg;
} else {
aot::ArrayArg arr_arg{};
arr_arg.dtype_name = arg.dt.to_string();
arr_arg.dtype_name = PrimitiveType::get(arg.dtype).to_string();
arr_arg.field_dim = arg.field_dim;
arr_arg.element_shape = arg.element_shape;
arr_arg.shape_offset_in_args_buf = arg.index * sizeof(int32_t);
Expand Down Expand Up @@ -105,32 +105,6 @@ AotModuleBuilderImpl::AotModuleBuilderImpl(
}
}

// Maps a Taichi primitive DataType to the integer enum value used in the
// Vulkan AOT field metadata. The numbering below is part of the serialized
// AOT format, so it must stay stable. Unsupported types abort via
// TI_NOT_IMPLEMENTED.
uint32_t AotModuleBuilderImpl::to_vk_dtype_enum(DataType dt) {
  // Guard-clause style: each supported primitive type returns immediately.
  if (dt == PrimitiveType::u64)
    return 0;
  if (dt == PrimitiveType::i64)
    return 1;
  if (dt == PrimitiveType::u32)
    return 2;
  if (dt == PrimitiveType::i32)
    return 3;
  if (dt == PrimitiveType::u16)
    return 4;
  if (dt == PrimitiveType::i16)
    return 5;
  if (dt == PrimitiveType::u8)
    return 6;
  if (dt == PrimitiveType::i8)
    return 7;
  if (dt == PrimitiveType::f64)
    return 8;
  if (dt == PrimitiveType::f32)
    return 9;
  TI_NOT_IMPLEMENTED
}

std::string AotModuleBuilderImpl::write_spv_file(
const std::string &output_dir,
const TaskAttributes &k,
Expand Down Expand Up @@ -194,7 +168,7 @@ void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
aot::CompiledFieldData field_data;
field_data.field_name = identifier;
field_data.is_scalar = is_scalar;
field_data.dtype = to_vk_dtype_enum(dt);
field_data.dtype = static_cast<int>(dt->cast<PrimitiveType>()->type);
field_data.dtype_name = dt.to_string();
field_data.shape = shape;
field_data.mem_offset_in_parent = dense_desc.mem_offset_in_parent_cell;
Expand Down
2 changes: 0 additions & 2 deletions taichi/backends/vulkan/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
const TaskAttributes &k,
const std::vector<uint32_t> &source_code) const;

uint32_t to_vk_dtype_enum(DataType dt);

const std::vector<CompiledSNodeStructs> &compiled_structs_;
TaichiAotData ti_aot_data_;
std::unique_ptr<Device> aot_target_device_;
Expand Down
22 changes: 11 additions & 11 deletions taichi/backends/vulkan/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,15 @@ class HostDeviceContextBlitter {
char *const device_base =
reinterpret_cast<char *>(device_->map(*device_args_buffer_));

#define TO_DEVICE(short_type, type) \
if (dt->is_primitive(PrimitiveTypeID::short_type)) { \
auto d = host_ctx_->get_arg<type>(i); \
reinterpret_cast<type *>(device_ptr)[0] = d; \
break; \
#define TO_DEVICE(short_type, type) \
if (arg.dtype == PrimitiveTypeID::short_type) { \
auto d = host_ctx_->get_arg<type>(i); \
reinterpret_cast<type *>(device_ptr)[0] = d; \
break; \
}

for (int i = 0; i < ctx_attribs_->args().size(); ++i) {
const auto &arg = ctx_attribs_->args()[i];
const auto dt = arg.dt;
char *device_ptr = device_base + arg.offset_in_mem;
do {
if (arg.is_array) {
Expand Down Expand Up @@ -118,13 +117,14 @@ class HostDeviceContextBlitter {
TO_DEVICE(f64, float64)
}
if (device_->get_cap(DeviceCapability::spirv_has_float16)) {
if (dt->is_primitive(PrimitiveTypeID::f16)) {
if (arg.dtype == PrimitiveTypeID::f16) {
auto d = fp16_ieee_from_fp32_value(host_ctx_->get_arg<float>(i));
reinterpret_cast<uint16 *>(device_ptr)[0] = d;
break;
}
}
TI_ERROR("Vulkan does not support arg type={}", data_type_name(arg.dt));
TI_ERROR("Vulkan does not support arg type={}",
PrimitiveType::get(arg.dtype).to_string());
} while (0);
}

Expand Down Expand Up @@ -196,8 +196,8 @@ class HostDeviceContextBlitter {
// *arg* on the host context.
const auto &ret = ctx_attribs_->rets()[i];
char *device_ptr = device_base + ret.offset_in_mem;
const auto dt = ret.dt;
const auto num = ret.stride / data_type_size(ret.dt);
const auto dt = PrimitiveType::get(ret.dtype);
const auto num = ret.stride / data_type_size(dt);
for (int j = 0; j < num; ++j) {
if (device_->get_cap(DeviceCapability::spirv_has_int8)) {
TO_HOST(i8, int8, j)
Expand Down Expand Up @@ -227,7 +227,7 @@ class HostDeviceContextBlitter {
}
}
TI_ERROR("Vulkan does not support return value type={}",
data_type_name(ret.dt));
data_type_name(PrimitiveType::get(ret.dtype)));
}
}
#undef TO_HOST
Expand Down
25 changes: 16 additions & 9 deletions taichi/codegen/spirv/kernel_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,13 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel)
rets_bytes_(0),
extra_args_bytes_(RuntimeContext::extra_args_size) {
arg_attribs_vec_.reserve(kernel.args.size());
// TODO: We should be able to limit Kernel args and rets to be primitive types
// as well, but let's leave that as a follow-up PR.
for (const auto &ka : kernel.args) {
ArgAttributes aa;
aa.dt = ka.dt;
const size_t dt_bytes = data_type_size(aa.dt);
TI_ASSERT(ka.dt->is<PrimitiveType>());
aa.dtype = ka.dt->cast<PrimitiveType>()->type;
const size_t dt_bytes = data_type_size(ka.dt);
aa.is_array = ka.is_array;
if (aa.is_array) {
aa.field_dim = ka.total_dim - ka.element_shape.size();
Expand All @@ -70,13 +73,16 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel)
RetAttributes ra;
size_t dt_bytes{0};
if (auto tensor_type = kr.dt->cast<TensorType>()) {
ra.dt = tensor_type->get_element_type();
dt_bytes = data_type_size(ra.dt);
auto tensor_dtype = tensor_type->get_element_type();
TI_ASSERT(tensor_dtype->is<PrimitiveType>());
ra.dtype = tensor_dtype->cast<PrimitiveType>()->type;
dt_bytes = data_type_size(tensor_dtype);
ra.is_array = true;
ra.stride = tensor_type->get_num_elements() * dt_bytes;
} else {
ra.dt = kr.dt;
dt_bytes = data_type_size(ra.dt);
TI_ASSERT(kr.dt->is<PrimitiveType>());
ra.dtype = kr.dt->cast<PrimitiveType>()->type;
dt_bytes = data_type_size(kr.dt);
ra.is_array = false;
ra.stride = dt_bytes;
}
Expand All @@ -88,9 +94,10 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel)
size_t bytes = offset;
for (int i = 0; i < vec->size(); ++i) {
auto &attribs = (*vec)[i];
const size_t dt_bytes = (attribs.is_array && !is_ret)
? sizeof(uint64_t)
: data_type_size(attribs.dt);
const size_t dt_bytes =
(attribs.is_array && !is_ret)
? sizeof(uint64_t)
: data_type_size(PrimitiveType::get(attribs.dtype));
// Align bytes to the nearest multiple of dt_bytes
bytes = (bytes + dt_bytes - 1) / dt_bytes * dt_bytes;
attribs.offset_in_mem = bytes;
Expand Down
10 changes: 8 additions & 2 deletions taichi/codegen/spirv/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,18 @@ class KernelContextAttributes {
size_t offset_in_mem{0};
// Index of the input arg or the return value in the host `Context`
int index{-1};
DataType dt;
PrimitiveTypeID dtype{PrimitiveTypeID::unknown};
bool is_array{false};
std::vector<int> element_shape;
std::size_t field_dim{0};

TI_IO_DEF(stride, offset_in_mem, index, is_array, element_shape, field_dim);
TI_IO_DEF(stride,
offset_in_mem,
index,
dtype,
is_array,
element_shape,
field_dim);
};

public:
Expand Down
14 changes: 8 additions & 6 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ class TaskCodegen : public IRVisitor {
// ir_->int_immediate_number(ir_->i32_type(), offset_in_mem);
// ir_->register_value(stmt->raw_name(), val);
} else {
const auto dt = arg_attribs.dt;
const auto dt = PrimitiveType::get(arg_attribs.dtype);
const auto val_type = ir_->get_primitive_type(dt);
spirv::Value buffer_val = ir_->make_value(
spv::OpAccessChain,
Expand Down Expand Up @@ -1806,9 +1806,9 @@ class TaskCodegen : public IRVisitor {
"arg_ptr" + std::to_string(arg.index),
arg.offset_in_mem);
} else {
struct_components_.emplace_back(ir_->get_primitive_type(arg.dt),
"arg" + std::to_string(arg.index),
arg.offset_in_mem);
struct_components_.emplace_back(
ir_->get_primitive_type(PrimitiveType::get(arg.dtype)),
"arg" + std::to_string(arg.index), arg.offset_in_mem);
}
}
// A compromise for use in constants buffer
Expand All @@ -1833,15 +1833,17 @@ class TaskCodegen : public IRVisitor {
// Now we only have one ret
TI_ASSERT(ctx_attribs_->rets().size() == 1);
for (auto &ret : ctx_attribs_->rets()) {
if (auto tensor_type = ret.dt->cast<TensorType>()) {
if (auto tensor_type =
PrimitiveType::get(ret.dtype)->cast<TensorType>()) {
struct_components_.emplace_back(
ir_->get_array_type(
ir_->get_primitive_type(tensor_type->get_element_type()),
tensor_type->get_num_elements()),
"ret" + std::to_string(ret.index), ret.offset_in_mem);
} else {
struct_components_.emplace_back(
ir_->get_array_type(ir_->get_primitive_type(ret.dt), 1),
ir_->get_array_type(
ir_->get_primitive_type(PrimitiveType::get(ret.dtype)), 1),
"ret" + std::to_string(ret.index), ret.offset_in_mem);
}
}
Expand Down
5 changes: 5 additions & 0 deletions taichi/program/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ struct RuntimeContext {
void set_device_allocation(int i, bool is_device_allocation_) {
is_device_allocation[i] = is_device_allocation_;
}

// Reads the i-th return-value slot from the result buffer and reinterprets
// its bit pattern as a value of type T (no numeric conversion is performed).
template <typename T>
T get_ret(int i) {
  const auto raw_bits = result_buffer[i];
  return taichi_union_cast_with_different_sizes<T>(raw_bits);
}
#endif
};

Expand Down
29 changes: 28 additions & 1 deletion tests/cpp/aot/aot_save_load_test.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "gtest/gtest.h"
#include "taichi/ir/ir_builder.h"
#include "taichi/ir/statements.h"
#include "taichi/inc/constants.h"
#include "taichi/program/program.h"
#ifdef TI_WITH_VULKAN
#include "taichi/backends/vulkan/aot_module_loader_impl.h"
Expand Down Expand Up @@ -29,7 +30,24 @@ using namespace lang;

auto aot_builder = program.make_aot_module_builder(Arch::vulkan);

std::unique_ptr<Kernel> kernel_init, kernel_ret;
std::unique_ptr<Kernel> kernel_init, kernel_ret, kernel_simple_ret;

{
/*
@ti.kernel
def ret() -> ti.f32:
sum = 0.2
return sum
*/
IRBuilder builder;
auto *sum = builder.create_local_var(PrimitiveType::f32);
builder.create_local_store(sum, builder.get_float32(0.2));
builder.create_return(builder.create_local_load(sum));

kernel_simple_ret =
std::make_unique<Kernel>(program, builder.extract_ir(), "simple_ret");
kernel_simple_ret->insert_ret(PrimitiveType::f32);
}

{
/*
Expand Down Expand Up @@ -79,6 +97,7 @@ using namespace lang;
kernel_ret->insert_ret(PrimitiveType::i32);
}

aot_builder->add("simple_ret", kernel_simple_ret.get());
aot_builder->add_field("place", place, true, place->dt, {n}, 1, 1);
aot_builder->add("init", kernel_init.get());
aot_builder->add("ret", kernel_ret.get());
Expand All @@ -103,6 +122,7 @@ TEST(AotSaveLoad, Vulkan) {
std::make_unique<taichi::lang::MemoryPool>(Arch::vulkan, nullptr);
result_buffer = (taichi::uint64 *)memory_pool->allocate(
sizeof(taichi::uint64) * taichi_result_buffer_entries, 8);
host_ctx.result_buffer = result_buffer;

// Create Taichi Device for computation
lang::vulkan::VulkanDeviceCreator::Params evd_params;
Expand Down Expand Up @@ -132,6 +152,13 @@ TEST(AotSaveLoad, Vulkan) {
EXPECT_EQ(root_size, 64);
vulkan_runtime->add_root_buffer(root_size);

auto simple_ret_kernel = vk_module->get_kernel("simple_ret");
EXPECT_TRUE(simple_ret_kernel);

simple_ret_kernel->launch(&host_ctx);
vulkan_runtime->synchronize();
EXPECT_FLOAT_EQ(host_ctx.get_ret<float>(0), 0.2);

auto init_kernel = vk_module->get_kernel("init");
EXPECT_TRUE(init_kernel);

Expand Down

0 comments on commit aff5aca

Please sign in to comment.