Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aot] Improve C++ wrapper implementation #6146

Merged
merged 2 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 96 additions & 47 deletions c_api/include/taichi/cpp/taichi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ THandle move_handle(THandle &handle) {
class Memory {
TiRuntime runtime_{TI_NULL_HANDLE};
TiMemory memory_{TI_NULL_HANDLE};
size_t size_{0};
bool should_destroy_{false};

public:
Expand All @@ -88,10 +89,14 @@ class Memory {
Memory(Memory &&b)
: runtime_(detail::move_handle(b.runtime_)),
memory_(detail::move_handle(b.memory_)),
size_(std::exchange(b.size_, 0)),
should_destroy_(std::exchange(b.should_destroy_, false)) {
}
Memory(TiRuntime runtime, TiMemory memory, bool should_destroy)
: runtime_(runtime), memory_(memory), should_destroy_(should_destroy) {
Memory(TiRuntime runtime, TiMemory memory, size_t size, bool should_destroy)
: runtime_(runtime),
memory_(memory),
size_(size),
should_destroy_(should_destroy) {
}
~Memory() {
destroy();
Expand All @@ -102,10 +107,36 @@ class Memory {
destroy();
runtime_ = detail::move_handle(b.runtime_);
memory_ = detail::move_handle(b.memory_);
size_ = std::exchange(b.size_, 0);
should_destroy_ = std::exchange(b.should_destroy_, false);
return *this;
}

void *map() const {
return ti_map_memory(runtime_, memory_);
}
void unmap() const {
ti_unmap_memory(runtime_, memory_);
}

inline void read(void *dst, size_t size) const {
PENGUINLIONG marked this conversation as resolved.
Show resolved Hide resolved
void *src = map();
if (src != nullptr) {
std::memcpy(dst, src, size);
}
unmap();
}
inline void write(const void *src, size_t size) const {
void *dst = map();
if (dst != nullptr) {
std::memcpy(dst, src, size);
}
unmap();
}

constexpr size_t size() const {
return size_;
}
constexpr TiMemory memory() const {
return memory_;
}
Expand All @@ -116,32 +147,28 @@ class Memory {

template <typename T>
class NdArray {
TiRuntime runtime_{TI_NULL_HANDLE};
Memory memory_{};
TiNdArray ndarray_{};
bool should_destroy_{false};

public:
constexpr bool is_valid() const {
return ndarray_.memory != nullptr;
return memory_.is_valid();
}
inline void destroy() {
if (should_destroy_) {
ti_free_memory(runtime_, ndarray_.memory);
ndarray_.memory = TI_NULL_HANDLE;
should_destroy_ = false;
}
memory_.destroy();
}

NdArray() {
}
NdArray(const NdArray<T> &) = delete;
NdArray(NdArray<T> &&b)
: runtime_(detail::move_handle(b.runtime_)),
ndarray_(std::exchange(b.ndarray_, {})),
should_destroy_(std::exchange(b.should_destroy_, false)) {
: memory_(std::move(b.memory_)), ndarray_(std::exchange(b.ndarray_, {})) {
}
NdArray(TiRuntime runtime, const TiNdArray &ndarray, bool should_destroy)
: runtime_(runtime), ndarray_(ndarray), should_destroy_(should_destroy) {
NdArray(Memory &&memory, const TiNdArray &ndarray)
: memory_(std::move(memory)), ndarray_(ndarray) {
if (ndarray.memory != memory_) {
ti_set_last_error(TI_ERROR_INVALID_ARGUMENT, "ndarray.memory != memory");
}
}
~NdArray() {
destroy();
Expand All @@ -150,44 +177,56 @@ class NdArray {
NdArray<T> &operator=(const NdArray<T> &) = delete;
NdArray<T> &operator=(NdArray<T> &&b) {
destroy();
runtime_ = detail::move_handle(b.runtime_);
memory_ = std::move(b.memory_);
ndarray_ = std::exchange(b.ndarray_, {});
should_destroy_ = std::exchange(b.should_destroy_, false);
return *this;
}

inline void *map() {
return ti_map_memory(runtime_, ndarray_.memory);
inline void *map() const {
return memory_.map();
}
inline void unmap() {
return ti_unmap_memory(runtime_, ndarray_.memory);
inline void unmap() const {
return memory_.unmap();
}

inline void read(T *dst, size_t size) {
T *src = (T *)map();
if (src != nullptr) {
std::memcpy(dst, src, size);
}
unmap();
inline void read(T *dst, size_t size) const {
memory_.read(dst, size);
}
inline void read(std::vector<T> &dst) {
inline void read(std::vector<T> &dst) const {
read(dst.data(), dst.size() * sizeof(T));
}
inline void write(const T *src, size_t size) {
T *dst = (T *)map();
if (dst != nullptr) {
std::memcpy(dst, src, size);
}
unmap();
template <typename U>
inline void read(std::vector<U> &dst) const {
static_assert(sizeof(U) % sizeof(T) == 0,
"sizeof(U) must be a multiple of sizeof(T)");
read((T *)dst.data(), dst.size() * sizeof(U));
}
inline void write(const T *src, size_t size) const {
memory_.write(src, size);
}
inline void write(const std::vector<T> &src) {
inline void write(const std::vector<T> &src) const {
write(src.data(), src.size() * sizeof(T));
}
template <typename U>
inline void write(const std::vector<U> &src) const {
static_assert(sizeof(U) % sizeof(T) == 0,
"sizeof(U) must be a multiple of sizeof(T)");
write((const T *)src.data(), src.size() * sizeof(U));
}

constexpr TiMemory memory() const {
return ndarray_.memory;
constexpr TiDataType elem_type() const {
return ndarray_.elem_type;
}
constexpr const TiNdShape &shape() const {
return ndarray_.shape;
}
constexpr const TiNdShape &elem_shape() const {
return ndarray_.elem_shape;
}
constexpr const Memory &memory() const {
return memory_;
}
constexpr TiNdArray ndarray() const {
constexpr const TiNdArray &ndarray() const {
return ndarray_;
}
constexpr operator TiNdArray() const {
Expand Down Expand Up @@ -291,6 +330,9 @@ class Texture {
return *this;
}

constexpr const Image &image() const {
return image_;
}
constexpr TiTexture texture() const {
return texture_;
}
Expand Down Expand Up @@ -598,6 +640,7 @@ class Event {
};

class Runtime {
TiArch arch_{TI_ARCH_MAX_ENUM};
TiRuntime runtime_{TI_NULL_HANDLE};
bool should_destroy_{false};

Expand All @@ -617,14 +660,15 @@ class Runtime {
}
Runtime(const Runtime &) = delete;
Runtime(Runtime &&b)
: runtime_(detail::move_handle(b.runtime_)),
: arch_(std::exchange(b.arch_, TI_ARCH_MAX_ENUM)),
runtime_(detail::move_handle(b.runtime_)),
should_destroy_(std::exchange(b.should_destroy_, false)) {
}
Runtime(TiArch arch)
: runtime_(ti_create_runtime(arch)), should_destroy_(true) {
: arch_(arch), runtime_(ti_create_runtime(arch)), should_destroy_(true) {
}
Runtime(TiRuntime runtime, bool should_destroy)
: runtime_(runtime), should_destroy_(should_destroy) {
Runtime(TiArch arch, TiRuntime runtime, bool should_destroy)
: arch_(arch), runtime_(runtime), should_destroy_(should_destroy) {
}
~Runtime() {
destroy();
Expand All @@ -639,7 +683,7 @@ class Runtime {

Memory allocate_memory(const TiMemoryAllocateInfo &allocate_info) {
TiMemory memory = ti_allocate_memory(runtime_, &allocate_info);
return Memory(runtime_, memory, true);
return Memory(runtime_, memory, allocate_info.size, true);
}
Memory allocate_memory(size_t size) {
TiMemoryAllocateInfo allocate_info{};
Expand All @@ -648,8 +692,8 @@ class Runtime {
return allocate_memory(allocate_info);
}
template <typename T>
NdArray<T> allocate_ndarray(std::vector<uint32_t> shape,
std::vector<uint32_t> elem_shape,
NdArray<T> allocate_ndarray(const std::vector<uint32_t> &shape = {},
const std::vector<uint32_t> &elem_shape = {},
bool host_access = false) {
size_t size = sizeof(T);
TiNdArray ndarray{};
Expand All @@ -666,13 +710,15 @@ class Runtime {
}
ndarray.elem_shape.dim_count = elem_shape.size();
ndarray.elem_type = detail::templ2dtype<T>::value;

TiMemoryAllocateInfo allocate_info{};
allocate_info.size = size;
allocate_info.host_read = host_access;
allocate_info.host_write = host_access;
allocate_info.usage = TI_MEMORY_USAGE_STORAGE_BIT;
ndarray.memory = ti_allocate_memory(runtime_, &allocate_info);
return NdArray<T>(runtime_, std::move(ndarray), true);
Memory memory = allocate_memory(allocate_info);
ndarray.memory = memory;
return NdArray<T>(std::move(memory), ndarray);
}

Image allocate_image(const TiImageAllocateInfo &allocate_info) {
Expand Down Expand Up @@ -734,6 +780,9 @@ class Runtime {
ti_wait(runtime_);
}

constexpr TiArch arch() const {
return arch_;
}
constexpr TiRuntime runtime() const {
return runtime_;
}
Expand Down
4 changes: 4 additions & 0 deletions taichi/rhi/vulkan/vulkan_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,11 @@ VulkanCommandList::VulkanCommandList(VulkanDevice *ti_device,
: ti_device_(ti_device),
stream_(stream),
device_(ti_device->vk_device()),
#if !defined(__APPLE__)
query_pool_(vkapi::create_query_pool(ti_device->vk_device())),
#else
query_pool_(),
#endif
buffer_(buffer) {
VkCommandBufferBeginInfo info{};
info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
Expand Down
2 changes: 1 addition & 1 deletion taichi/rhi/vulkan/vulkan_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ bool VulkanLoader::init(PFN_vkGetInstanceProcAddr get_proc_addr) {
// (penguinliong) So that MoltenVK instances can be imported.
if (get_proc_addr != nullptr) {
volkInitializeCustom(get_proc_addr);
initialized = check_vulkan_device();
initialized = true;
return;
}
#if defined(__APPLE__)
Expand Down