diff --git a/c_api/include/taichi/taichi_vulkan.h b/c_api/include/taichi/taichi_vulkan.h index fc1e58ce58d3d..74433853a5d83 100644 --- a/c_api/include/taichi/taichi_vulkan.h +++ b/c_api/include/taichi/taichi_vulkan.h @@ -8,6 +8,7 @@ extern "C" { // structure.vulkan_runtime_interop_info typedef struct TiVulkanRuntimeInteropInfo { + PFN_vkGetInstanceProcAddr get_instance_proc_addr; uint32_t api_version; VkInstance instance; VkPhysicalDevice physical_device; diff --git a/c_api/src/taichi_vulkan_impl.cpp b/c_api/src/taichi_vulkan_impl.cpp index bf458b9601e6f..9d2c1ac637a42 100644 --- a/c_api/src/taichi_vulkan_impl.cpp +++ b/c_api/src/taichi_vulkan_impl.cpp @@ -18,7 +18,8 @@ VulkanRuntimeImported::Workaround::Workaround( : vk_device{} { // FIXME: This part is copied from `vulkan_runtime_creator.cpp` which should // be refactorized I guess. - if (!taichi::lang::vulkan::VulkanLoader::instance().init()) { + if (!taichi::lang::vulkan::VulkanLoader::instance().init( + params.get_proc_addr)) { throw std::runtime_error("Error loading vulkan"); } taichi::lang::vulkan::VulkanLoader::instance().load_instance(params.instance); @@ -162,6 +163,7 @@ TiRuntime ti_import_vulkan_runtime( TI_CAPI_ARGUMENT_NULL_RV(interop_info->device); taichi::lang::vulkan::VulkanDevice::Params params{}; + params.get_proc_addr = interop_info->get_instance_proc_addr; params.instance = interop_info->instance; params.physical_device = interop_info->physical_device; params.device = interop_info->device; diff --git a/c_api/taichi.json b/c_api/taichi.json index dd5fa4d73aa07..caf75e43296b3 100644 --- a/c_api/taichi.json +++ b/c_api/taichi.json @@ -861,6 +861,10 @@ "name": "vulkan_runtime_interop_info", "type": "structure", "fields": [ + { + "name": "get_instance_proc_addr", + "type": "PFN_vkGetInstanceProcAddr" + }, { "name": "api_version", "type": "uint32_t" diff --git a/misc/generate_c_api.py b/misc/generate_c_api.py index 36379f16b1281..fde1138ddc2ea 100644 --- a/misc/generate_c_api.py +++ b/misc/generate_c_api.py @@ -169,6 +169,7 @@ def generate_module_header(module): BuiltInType("VkImageLayout", "VkImageLayout"), BuiltInType("VkImageUsageFlags", "VkImageUsageFlags"), BuiltInType("VkImageViewType", "VkImageViewType"), + BuiltInType("PFN_vkGetInstanceProcAddr", "PFN_vkGetInstanceProcAddr"), BuiltInType("char", "char"), } diff --git a/taichi/rhi/vulkan/vulkan_device.h b/taichi/rhi/vulkan/vulkan_device.h index eacedcbbb80e8..3c760075589e6 100644 --- a/taichi/rhi/vulkan/vulkan_device.h +++ b/taichi/rhi/vulkan/vulkan_device.h @@ -564,6 +564,7 @@ class VulkanStream : public Stream { class TI_DLL_EXPORT VulkanDevice : public GraphicsDevice { public: struct Params { + PFN_vkGetInstanceProcAddr get_proc_addr{nullptr}; VkInstance instance; VkPhysicalDevice physical_device; VkDevice device; diff --git a/taichi/rhi/vulkan/vulkan_loader.cpp b/taichi/rhi/vulkan/vulkan_loader.cpp index 662357e297b15..9bec82595ea8e 100644 --- a/taichi/rhi/vulkan/vulkan_loader.cpp +++ b/taichi/rhi/vulkan/vulkan_loader.cpp @@ -88,11 +88,17 @@ bool VulkanLoader::check_vulkan_device() { return found_device_with_compute; } -bool VulkanLoader::init() { +bool VulkanLoader::init(PFN_vkGetInstanceProcAddr get_proc_addr) { std::call_once(init_flag_, [&]() { if (initialized) { return; } + // (penguinliong) So that MoltenVK instances can be imported. + if (get_proc_addr != nullptr) { + volkInitializeCustom(get_proc_addr); + initialized = check_vulkan_device(); + return; + } #if defined(__APPLE__) vulkan_rt_ = std::make_unique(runtime_lib_dir() + "/libMoltenVK.dylib"); diff --git a/taichi/rhi/vulkan/vulkan_loader.h b/taichi/rhi/vulkan/vulkan_loader.h index e4c95348a3633..32467912bc9ca 100644 --- a/taichi/rhi/vulkan/vulkan_loader.h +++ b/taichi/rhi/vulkan/vulkan_loader.h @@ -25,7 +25,7 @@ class TI_DLL_EXPORT VulkanLoader { void load_instance(VkInstance instance_); void load_device(VkDevice device_); - bool init(); + bool init(PFN_vkGetInstanceProcAddr get_proc_addr = nullptr); PFN_vkVoidFunction load_function(const char *name); VkInstance get_instance() { return vulkan_instance_;