From 504671614f168d1fe16ef13adf032ce91fd92bfc Mon Sep 17 00:00:00 2001 From: PENGUINLIONG Date: Wed, 14 Sep 2022 19:40:15 +0800 Subject: [PATCH] [aot] Dump required device capability in AOT module meta (#6056) --- .github/workflows/scripts/aot-demo.sh | 2 +- taichi/aot/module_data.h | 4 +++- taichi/rhi/device.h | 5 +++++ taichi/runtime/gfx/aot_module_builder_impl.cpp | 2 ++ taichi/runtime/gfx/aot_utils.h | 4 +++- 5 files changed, 14 insertions(+), 3 deletions(-) diff --git a/.github/workflows/scripts/aot-demo.sh b/.github/workflows/scripts/aot-demo.sh index aeff7dac6dfd3..9ee8c0b127dd2 100755 --- a/.github/workflows/scripts/aot-demo.sh +++ b/.github/workflows/scripts/aot-demo.sh @@ -48,7 +48,7 @@ function prepare-unity-build-env { cd taichi # Dependencies - git clone --reference-if-able /var/lib/git-cache https://github.com/taichi-dev/Taichi-UnityExample + git clone --reference-if-able /var/lib/git-cache -b upgrade-modules1 https://github.com/taichi-dev/Taichi-UnityExample python misc/generate_unity_language_binding.py cp c_api/unity/*.cs Taichi-UnityExample/Assets/Taichi/Generated diff --git a/taichi/aot/module_data.h b/taichi/aot/module_data.h index f0c6ea740ae06..824de73fe6283 100644 --- a/taichi/aot/module_data.h +++ b/taichi/aot/module_data.h @@ -3,6 +3,7 @@ #include #include +#include "taichi/rhi/device.h" #include "taichi/common/core.h" #include "taichi/common/serialization.h" @@ -120,6 +121,7 @@ struct ModuleData { std::unordered_map kernels; std::unordered_map kernel_tmpls; std::vector fields; + std::map required_caps; size_t root_buffer_size; @@ -129,7 +131,7 @@ struct ModuleData { ts.write_to_file(path); } - TI_IO_DEF(kernels, kernel_tmpls, fields, root_buffer_size); + TI_IO_DEF(kernels, kernel_tmpls, fields, required_caps, root_buffer_size); }; } // namespace aot diff --git a/taichi/rhi/device.h b/taichi/rhi/device.h index 3e8bce1d35118..46576da09f2aa 100644 --- a/taichi/rhi/device.h +++ b/taichi/rhi/device.h @@ -418,6 +418,11 @@ class Device { dest.set_cap(k, v); } } + void clone_caps(std::map &dest) const { + for (const auto &[k, v] : caps_) { + dest[k] = v; + } + } void print_all_cap() const; diff --git a/taichi/runtime/gfx/aot_module_builder_impl.cpp b/taichi/runtime/gfx/aot_module_builder_impl.cpp index afd6c10e707dc..364a1657dfb91 100644 --- a/taichi/runtime/gfx/aot_module_builder_impl.cpp +++ b/taichi/runtime/gfx/aot_module_builder_impl.cpp @@ -29,6 +29,7 @@ class AotDataConverter { res.kernels[ker.name] = val; } res.fields = in.fields; + res.required_caps = in.required_caps; res.root_buffer_size = in.root_buffer_size; return res; } @@ -110,6 +111,7 @@ AotModuleBuilderImpl::AotModuleBuilderImpl( aot_target_device_ = target_device ? std::move(target_device) : std::make_unique(device_api_backend_); + aot_target_device_->clone_caps(ti_aot_data_.required_caps); if (!compiled_structs.empty()) { ti_aot_data_.root_buffer_size = compiled_structs[0].root_size; } diff --git a/taichi/runtime/gfx/aot_utils.h b/taichi/runtime/gfx/aot_utils.h index e8c1f5b0ea150..2bfdc367f57f1 100644 --- a/taichi/runtime/gfx/aot_utils.h +++ b/taichi/runtime/gfx/aot_utils.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include "taichi/codegen/spirv/kernel_utils.h" #include "taichi/aot/module_loader.h" @@ -17,9 +18,10 @@ struct TaichiAotData { std::vector>> spirv_codes; std::vector kernels; std::vector fields; + std::map required_caps; size_t root_buffer_size{0}; - TI_IO_DEF(kernels, fields, root_buffer_size); + TI_IO_DEF(kernels, fields, required_caps, root_buffer_size); }; } // namespace gfx