Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[infra] Refactor Vulkan runtime into true Common Runtime #5058

Merged
merged 7 commits into from
Jun 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmake/TaichiCXXFlags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ endif ()
# Do not enable lto for APPLE since it made linking extremely slow.
if (WIN32)
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -flto=thin")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS} -flto=thin")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS} -flto=thin")
endif()
endif()

Expand Down
20 changes: 7 additions & 13 deletions cmake/TaichiCore.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -372,24 +372,21 @@ add_subdirectory(external/SPIRV-Tools)
# https://github.com/KhronosGroup/SPIRV-Tools/issues/1569#issuecomment-390250792
target_link_libraries(${CORE_LIBRARY_NAME} PRIVATE SPIRV-Tools-opt ${SPIRV_TOOLS})

target_include_directories(${CORE_LIBRARY_NAME} PRIVATE external/SPIRV-Headers/include)
target_include_directories(${CORE_LIBRARY_NAME} PRIVATE external/SPIRV-Reflect)

add_subdirectory(taichi/runtime/gfx)
target_link_libraries(${CORE_LIBRARY_NAME} PRIVATE gfx_runtime)

# Vulkan Device API
if (TI_WITH_VULKAN)
include_directories(SYSTEM external/Vulkan-Headers/include)

include_directories(SYSTEM external/volk)

target_include_directories(${CORE_LIBRARY_NAME} PRIVATE external/SPIRV-Headers/include)
target_include_directories(${CORE_LIBRARY_NAME} PRIVATE external/SPIRV-Reflect)

# By specifying SYSTEM, we suppressed the warnings from third-party headers.
target_include_directories(${CORE_LIBRARY_NAME} SYSTEM PRIVATE external/VulkanMemoryAllocator/include)

if (LINUX)
# shaderc requires pthread
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
target_link_libraries(${CORE_LIBRARY_NAME} PRIVATE Threads::Threads)
endif()

if (APPLE)
find_library(MOLTEN_VK libMoltenVK.dylib PATHS $HOMEBREW_CELLAR/molten-vk $VULKAN_SDK REQUIRED)
configure_file(${MOLTEN_VK} ${CMAKE_BINARY_DIR}/libMoltenVK.dylib COPYONLY)
Expand All @@ -398,9 +395,6 @@ if (TI_WITH_VULKAN)
install(FILES ${CMAKE_BINARY_DIR}/libMoltenVK.dylib DESTINATION ${INSTALL_LIB_DIR}/runtime)
endif()
endif()

add_subdirectory(taichi/runtime/vulkan)
target_link_libraries(${CORE_LIBRARY_NAME} PRIVATE vulkan_runtime)
endif ()


Expand Down
15 changes: 7 additions & 8 deletions taichi/aot/module_loader.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "taichi/aot/module_loader.h"

#include "taichi/backends/vulkan/aot_module_loader_impl.h"
#include "taichi/runtime/gfx/aot_module_loader_impl.h"
#include "taichi/backends/metal/aot_module_loader_impl.h"

namespace taichi {
Expand Down Expand Up @@ -32,19 +32,18 @@ Kernel *KernelTemplate::get_kernel(
std::unique_ptr<Module> Module::load(Arch arch, std::any mod_params) {
if (arch == Arch::vulkan) {
#ifdef TI_WITH_VULKAN
return vulkan::make_aot_module(mod_params);
#else
TI_NOT_IMPLEMENTED
return gfx::make_aot_module(mod_params, arch);
#endif
} else if (arch == Arch::dx11) {
#ifdef TI_WITH_DX11
return gfx::make_aot_module(mod_params, arch);
#endif
} else if (arch == Arch::metal) {
#ifdef TI_WITH_METAL
return metal::make_aot_module(mod_params);
#else
TI_NOT_IMPLEMENTED
#endif
} else {
TI_NOT_IMPLEMENTED;
}
TI_NOT_IMPLEMENTED;
}

Kernel *Module::get_kernel(const std::string &name) {
Expand Down
20 changes: 12 additions & 8 deletions taichi/backends/dx/dx_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void dump_buffer(ID3D11Device *device,

void check_dx_error(HRESULT hr, const char *msg) {
if (!SUCCEEDED(hr)) {
TI_ERROR("Error in {}: {}", msg, hr);
TI_ERROR("Error in {}: {:x}", msg, uint32_t(hr));
}
}

Expand Down Expand Up @@ -593,16 +593,21 @@ DeviceAllocation Dx11Device::allocate_memory(const AllocParams &params) {

void Dx11Device::dealloc_memory(DeviceAllocation handle) {
uint32_t alloc_id = handle.alloc_id;
if (alloc_id_to_buffer_.count(alloc_id) == 0)
return;
if (alloc_id_to_buffer_.find(alloc_id) == alloc_id_to_buffer_.end())
TI_ERROR("Invalid handle, possible double free?");
ID3D11Buffer *buf = alloc_id_to_buffer_[alloc_id];
buf->Release();
alloc_id_to_buffer_.erase(alloc_id);
ID3D11UnorderedAccessView *uav = alloc_id_to_uav_[alloc_id];
uav->Release();
ID3D11Buffer *cpucopy = alloc_id_to_cpucopy_[alloc_id];
if (cpucopy)
cpucopy->Release();
if (alloc_id_to_cpucopy_.find(alloc_id) != alloc_id_to_cpucopy_.end()) {
alloc_id_to_cpucopy_[alloc_id]->Release();
alloc_id_to_cpucopy_.erase(alloc_id);
}
if (alloc_id_to_cb_copy_.find(alloc_id) != alloc_id_to_cb_copy_.end()) {
alloc_id_to_cb_copy_[alloc_id]->Release();
alloc_id_to_cb_copy_.erase(alloc_id);
}
alloc_id_to_uav_.erase(alloc_id);
}

Expand Down Expand Up @@ -724,10 +729,9 @@ ID3D11UnorderedAccessView *Dx11Device::alloc_id_to_uav(uint32_t alloc_id) {
}

ID3D11Buffer *Dx11Device::create_or_get_cb_buffer(uint32_t alloc_id) {
if (alloc_id_to_cb_copy_.count(alloc_id) > 0) {
if (alloc_id_to_cb_copy_.find(alloc_id) != alloc_id_to_cb_copy_.end()) {
return alloc_id_to_cb_copy_[alloc_id];
}
assert(alloc_id_to_buffer_.count(alloc_id) > 0);
ID3D11Buffer *buf = alloc_id_to_buffer_[alloc_id];
ID3D11Buffer *cb_buf;
HRESULT hr = create_constant_buffer_copy(device_, buf, &cb_buf);
Expand Down
54 changes: 40 additions & 14 deletions taichi/backends/dx/dx_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@
#include "taichi/backends/dx/dx_program.h"

#include "taichi/backends/dx/dx_device.h"
#include "taichi/backends/vulkan/snode_tree_manager.h"
#include "taichi/runtime/gfx/aot_module_builder_impl.h"
#include "taichi/runtime/gfx/snode_tree_manager.h"
#include "taichi/runtime/gfx/aot_module_loader_impl.h"

namespace taichi {
namespace lang {
namespace directx11 {

FunctionType compile_to_executable(Kernel *kernel,
vulkan::VkRuntime *runtime,
vulkan::SNodeTreeManager *snode_tree_mgr) {
gfx::GfxRuntime *runtime,
gfx::SNodeTreeManager *snode_tree_mgr) {
auto handle = runtime->register_taichi_kernel(
std::move(vulkan::run_codegen(kernel, runtime->get_ti_device(),
snode_tree_mgr->get_compiled_structs())));
std::move(gfx::run_codegen(kernel, runtime->get_ti_device(),
snode_tree_mgr->get_compiled_structs())));
return [runtime, handle](RuntimeContext &ctx) {
runtime->launch_kernel(handle, &ctx);
};
Expand All @@ -40,28 +42,52 @@ void Dx11ProgramImpl::materialize_runtime(MemoryPool *memory_pool,

device_ = std::make_shared<directx11::Dx11Device>();

vulkan::VkRuntime::Params params;
gfx::GfxRuntime::Params params;
params.host_result_buffer = *result_buffer_ptr;
params.device = device_.get();
runtime_ = std::make_unique<vulkan::VkRuntime>(std::move(params));
snode_tree_mgr_ = std::make_unique<vulkan::SNodeTreeManager>(runtime_.get());
runtime_ = std::make_unique<gfx::GfxRuntime>(std::move(params));
snode_tree_mgr_ = std::make_unique<gfx::SNodeTreeManager>(runtime_.get());
}

void Dx11ProgramImpl::synchronize() {
TI_NOT_IMPLEMENTED;
void Dx11ProgramImpl::compile_snode_tree_types(SNodeTree *tree) {
if (runtime_) {
snode_tree_mgr_->materialize_snode_tree(tree);
} else {
gfx::CompiledSNodeStructs compiled_structs =
gfx::compile_snode_structs(*tree->root());
aot_compiled_snode_structs_.push_back(compiled_structs);
}
}

void Dx11ProgramImpl::materialize_snode_tree(SNodeTree *tree,
uint64 *result_buffer_ptr) {
uint64 *result_buffer) {
snode_tree_mgr_->materialize_snode_tree(tree);
}

std::unique_ptr<AotModuleBuilder> Dx11ProgramImpl::make_aot_module_builder() {
return nullptr;
if (runtime_) {
return std::make_unique<gfx::AotModuleBuilderImpl>(
snode_tree_mgr_->get_compiled_structs(), Arch::dx11);
} else {
return std::make_unique<gfx::AotModuleBuilderImpl>(
aot_compiled_snode_structs_, Arch::dx11);
}
}

void Dx11ProgramImpl::destroy_snode_tree(SNodeTree *snode_tree) {
TI_NOT_IMPLEMENTED;
DeviceAllocation Dx11ProgramImpl::allocate_memory_ndarray(
std::size_t alloc_size,
uint64 *result_buffer) {
return get_compute_device()->allocate_memory(
{alloc_size, /*host_write=*/false, /*host_read=*/false,
/*export_sharing=*/false});
}

std::unique_ptr<aot::Kernel> Dx11ProgramImpl::make_aot_kernel(Kernel &kernel) {
spirv::lower(&kernel);
std::vector<gfx::CompiledSNodeStructs> compiled_structs;
gfx::GfxRuntime::RegisterParams kparams =
gfx::run_codegen(&kernel, get_compute_device(), compiled_structs);
return std::make_unique<gfx::KernelImpl>(runtime_.get(), std::move(kparams));
}

} // namespace lang
Expand Down
55 changes: 44 additions & 11 deletions taichi/backends/dx/dx_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#ifdef TI_WITH_DX11

#include "taichi/backends/dx/dx_device.h"
#include "taichi/runtime/vulkan/runtime.h"
#include "taichi/backends/vulkan/snode_tree_manager.h"
#include "taichi/runtime/gfx/runtime.h"
#include "taichi/runtime/gfx/snode_tree_manager.h"
#include "taichi/program/program_impl.h"

namespace taichi {
Expand All @@ -13,26 +13,59 @@ namespace lang {
class Dx11ProgramImpl : public ProgramImpl {
public:
Dx11ProgramImpl(CompileConfig &config);

FunctionType compile(Kernel *kernel, OffloadedStmt *offloaded) override;

std::size_t get_snode_num_dynamically_allocated(
SNode *snode,
uint64 *result_buffer) override {
return 0;
return 0; // TODO: support sparse
}
std::unique_ptr<AotModuleBuilder> make_aot_module_builder();

void compile_snode_tree_types(SNodeTree *tree) override;

void materialize_runtime(MemoryPool *memory_pool,
KernelProfilerBase *profiler,
uint64 **result_buffer_ptr) override;
virtual void materialize_snode_tree(SNodeTree *tree,
uint64 *result_buffer_ptr) override;
virtual void destroy_snode_tree(SNodeTree *snode_tree) override;
void synchronize() override;

void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override;

void synchronize() override {
runtime_->synchronize();
}

StreamSemaphore flush() override {
return runtime_->flush();
}

std::unique_ptr<AotModuleBuilder> make_aot_module_builder() override;

void destroy_snode_tree(SNodeTree *snode_tree) override {
TI_ASSERT(snode_tree_mgr_ != nullptr);
snode_tree_mgr_->destroy_snode_tree(snode_tree);
}

DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size,
uint64 *result_buffer) override;

Device *get_compute_device() override {
return device_.get();
}

Device *get_graphics_device() override {
return device_.get();
}

DevicePtr get_snode_tree_device_ptr(int tree_id) override {
return snode_tree_mgr_->get_snode_tree_device_ptr(tree_id);
}

std::unique_ptr<aot::Kernel> make_aot_kernel(Kernel &kernel) override;

private:
std::shared_ptr<directx11::Dx11Device> device_{nullptr};
std::unique_ptr<vulkan::VkRuntime> runtime_{nullptr};
std::unique_ptr<vulkan::SNodeTreeManager> snode_tree_mgr_{nullptr};
std::unique_ptr<gfx::GfxRuntime> runtime_{nullptr};
std::unique_ptr<gfx::SNodeTreeManager> snode_tree_mgr_{nullptr};
std::vector<spirv::CompiledSNodeStructs> aot_compiled_snode_structs_;
};

} // namespace lang
Expand Down
17 changes: 13 additions & 4 deletions taichi/backends/opengl/opengl_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ void GLResourceBinder::rw_buffer(uint32_t set,
DeviceAllocation alloc) {
TI_ASSERT_INFO(set == 0, "OpenGL only supports set = 0, requested set = {}",
set);
binding_map_[binding] = alloc.alloc_id;
ssbo_binding_map_[binding] = alloc.alloc_id;
}

void GLResourceBinder::buffer(uint32_t set,
Expand All @@ -211,7 +211,9 @@ void GLResourceBinder::buffer(uint32_t set,
void GLResourceBinder::buffer(uint32_t set,
uint32_t binding,
DeviceAllocation alloc) {
rw_buffer(set, binding, alloc);
TI_ASSERT_INFO(set == 0, "OpenGL only supports set = 0, requested set = {}",
set);
ubo_binding_map_[binding] = alloc.alloc_id;
}

void GLResourceBinder::image(uint32_t set,
Expand Down Expand Up @@ -295,10 +297,17 @@ void GLCommandList::bind_pipeline(Pipeline *p) {

void GLCommandList::bind_resources(ResourceBinder *_binder) {
GLResourceBinder *binder = static_cast<GLResourceBinder *>(_binder);
for (auto &[binding, buffer] : binder->binding_map()) {
for (auto &[binding, buffer] : binder->ssbo_binding_map()) {
auto cmd = std::make_unique<CmdBindBufferToIndex>();
cmd->buffer = buffer;
cmd->index = binding;
recorded_commands_.push_back(std::move(cmd));
}
for (auto &[binding, buffer] : binder->ubo_binding_map()) {
auto cmd = std::make_unique<CmdBindBufferToIndex>();
cmd->buffer = buffer;
cmd->index = binding;
cmd->target = GL_UNIFORM_BUFFER;
recorded_commands_.push_back(std::move(cmd));
}
}
Expand Down Expand Up @@ -682,7 +691,7 @@ void GLCommandList::CmdBindPipeline::execute() {
}

void GLCommandList::CmdBindBufferToIndex::execute() {
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, index, buffer);
glBindBufferBase(target, index, buffer);
check_opengl_error("glBindBufferBase");
}

Expand Down
12 changes: 9 additions & 3 deletions taichi/backends/opengl/opengl_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,17 @@ class GLResourceBinder : public ResourceBinder {
// index_width = 2 -> uint16 index
void index_buffer(DevicePtr ptr, size_t index_width) override;

const std::unordered_map<uint32_t, GLuint> &binding_map() {
return binding_map_;
const std::unordered_map<uint32_t, GLuint> &ssbo_binding_map() {
return ssbo_binding_map_;
}

const std::unordered_map<uint32_t, GLuint> &ubo_binding_map() {
return ubo_binding_map_;
}

private:
std::unordered_map<uint32_t, GLuint> binding_map_;
std::unordered_map<uint32_t, GLuint> ssbo_binding_map_;
std::unordered_map<uint32_t, GLuint> ubo_binding_map_;
};

class GLPipeline : public Pipeline {
Expand Down Expand Up @@ -141,6 +146,7 @@ class GLCommandList : public CommandList {
struct CmdBindBufferToIndex : public Cmd {
GLuint buffer{0};
GLuint index{0};
GLenum target{GL_SHADER_STORAGE_BUFFER};
void execute() override;
};

Expand Down
Loading