Skip to content

Commit

Permalink
[dx12] Add ti.dx12. (#6174)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
python3kgae and pre-commit-ci[bot] authored Oct 8, 2022
1 parent ed12813 commit 42d4087
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 24 deletions.
20 changes: 13 additions & 7 deletions python/taichi/lang/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,14 @@
"""
# ----------------------

gpu = [cuda, metal, vulkan, opengl, dx11]
dx12 = _ti_core.dx12
"""The DX11 backend.
"""
# ----------------------

gpu = [cuda, metal, vulkan, opengl, dx11, dx12]
"""A list of GPU backends supported on the current system.
Currently contains 'cuda', 'metal', 'opengl', 'vulkan', 'dx11'.
Currently contains 'cuda', 'metal', 'opengl', 'vulkan', 'dx11', 'dx12'.
When this is used, Taichi automatically picks the matching GPU backend. If no
GPU is detected, Taichi falls back to the CPU backend.
Expand Down Expand Up @@ -726,6 +731,7 @@ def is_arch_supported(arch, use_gles=False):
cc: _ti_core.with_cc,
vulkan: _ti_core.with_vulkan,
dx11: _ti_core.with_dx11,
dx12: _ti_core.with_dx12,
wasm: lambda: True,
cpu: lambda: True,
}
Expand Down Expand Up @@ -765,9 +771,9 @@ def get_compute_stream_device_time_elapsed_us() -> float:

__all__ = [
'i', 'ij', 'ijk', 'ijkl', 'ijl', 'ik', 'ikl', 'il', 'j', 'jk', 'jkl', 'jl',
'k', 'kl', 'l', 'x86_64', 'x64', 'dx11', 'wasm', 'arm64', 'cc', 'cpu',
'cuda', 'gpu', 'metal', 'opengl', 'vulkan', 'extension', 'loop_config',
'global_thread_idx', 'assume_in_range', 'block_local', 'cache_read_only',
'init', 'mesh_local', 'no_activate', 'reset', 'mesh_patch_idx',
'get_compute_stream_device_time_elapsed_us'
'k', 'kl', 'l', 'x86_64', 'x64', 'dx11', 'dx12', 'wasm', 'arm64', 'cc',
'cpu', 'cuda', 'gpu', 'metal', 'opengl', 'vulkan', 'extension',
'loop_config', 'global_thread_idx', 'assume_in_range', 'block_local',
'cache_read_only', 'init', 'mesh_local', 'no_activate', 'reset',
'mesh_patch_idx', 'get_compute_stream_device_time_elapsed_us'
]
9 changes: 9 additions & 0 deletions taichi/codegen/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#if defined(TI_WITH_CUDA)
#include "taichi/codegen/cuda/codegen_cuda.h"
#endif
#if defined(TI_WITH_DX12)
#include "taichi/codegen/dx12/codegen_dx12.h"
#endif
#include "taichi/system/timer.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/transforms.h"
Expand Down Expand Up @@ -47,6 +50,12 @@ std::unique_ptr<KernelCodeGen> KernelCodeGen::create(Arch arch,
return std::make_unique<KernelCodeGenCUDA>(kernel, stmt);
#else
TI_NOT_IMPLEMENTED
#endif
} else if (arch == Arch::dx12) {
#if defined(TI_WITH_DX12)
return std::make_unique<KernelCodeGenDX12>(kernel, stmt);
#else
TI_NOT_IMPLEMENTED
#endif
} else {
TI_NOT_IMPLEMENTED
Expand Down
9 changes: 9 additions & 0 deletions taichi/python/export_misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
#include "taichi/rhi/opengl/opengl_api.h"
#endif

#ifdef TI_WITH_DX12
#include "taichi/rhi/dx12/dx12_api.h"
#endif

#ifdef TI_WITH_CC
namespace taichi::lang::cccp {
extern bool is_c_backend_available();
Expand Down Expand Up @@ -163,6 +167,11 @@ void export_misc(py::module &m) {
#else
m.def("with_dx11", []() { return false; });
#endif
#ifdef TI_WITH_DX12
m.def("with_dx12", taichi::lang::directx12::is_dx12_api_available);
#else
m.def("with_dx12", []() { return false; });
#endif

#ifdef TI_WITH_CC
m.def("with_cc", taichi::lang::cccp::is_c_backend_available);
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/llvm/llvm_runtime_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ LlvmRuntimeExecutor::LlvmRuntimeExecutor(CompileConfig &config,
llvm_context_device_ =
std::make_unique<TaichiLLVMContext>(config_, Arch::dx12);
// FIXME: add dx12 JIT.
// llvm_context_device_->init_runtime_jit_module();
llvm_context_device_->init_runtime_jit_module();
}
#endif

Expand Down
17 changes: 17 additions & 0 deletions tests/cpp/aot/llvm/kernel_aot_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "taichi/system/memory_pool.h"
#include "taichi/runtime/cpu/aot_module_loader_impl.h"
#include "taichi/runtime/cuda/aot_module_loader_impl.h"
#include "taichi/runtime/dx12/aot_module_loader_impl.h"
#include "taichi/rhi/cuda/cuda_driver.h"
#include "taichi/platform/cuda/detect_cuda.h"

Expand Down Expand Up @@ -101,4 +102,20 @@ TEST(LlvmAotTest, CudaKernel) {
}
}

#ifdef TI_WITH_DX12
TEST(LlvmAotTest, DX12Kernel) {
directx12::AotModuleParams aot_params;
const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");

std::stringstream aot_mod_ss;
aot_mod_ss << folder_dir;
aot_params.module_path = aot_mod_ss.str();
// FIXME: add executor.
auto mod = directx12::make_aot_module(aot_params, Arch::dx12);
auto *k_run = mod->get_kernel("run");
EXPECT_TRUE(k_run);
// FIXME: launch the kernel and check result.
}
#endif

} // namespace taichi::lang
2 changes: 2 additions & 0 deletions tests/cpp/aot/python_scripts/kernel_aot_test1.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,7 @@ def run(base: int, arr: ti.types.ndarray()):
compile_kernel_aot_test1(arch=ti.vulkan)
elif args.arch == "opengl":
compile_kernel_aot_test1(arch=ti.opengl)
elif args.arch == "dx12":
compile_kernel_aot_test1(arch=ti.dx12)
else:
assert False
33 changes: 17 additions & 16 deletions tests/python/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,23 @@ def _get_expected_matrix_apis():
'atomic_sub', 'atomic_xor', 'axes', 'bit_cast', 'bit_shr', 'block_local',
'cache_read_only', 'cast', 'cc', 'ceil', 'cos', 'cpu', 'cuda',
'data_oriented', 'dataclass', 'deactivate', 'deactivate_all_snodes',
'dx11', 'eig', 'exp', 'experimental', 'extension', 'f16', 'f32', 'f64',
'field', 'float16', 'float32', 'float64', 'floor', 'func', 'get_addr',
'get_compute_stream_device_time_elapsed_us', 'global_thread_idx', 'gpu',
'graph', 'grouped', 'hex_to_rgb', 'i', 'i16', 'i32', 'i64', 'i8', 'ij',
'ijk', 'ijkl', 'ijl', 'ik', 'ikl', 'il', 'init', 'int16', 'int32', 'int64',
'int8', 'is_active', 'is_logging_effective', 'j', 'jk', 'jkl', 'jl', 'k',
'kernel', 'kl', 'l', 'lang', 'length', 'linalg', 'log', 'loop_config',
'math', 'max', 'mesh_local', 'mesh_patch_idx', 'metal', 'min', 'ndarray',
'ndrange', 'no_activate', 'one', 'opengl', 'polar_decompose', 'pow',
'profiler', 'randn', 'random', 'raw_div', 'raw_mod', 'ref',
'rescale_index', 'reset', 'rgb_to_hex', 'root', 'round', 'rsqrt', 'select',
'set_logging_level', 'simt', 'sin', 'solve', 'sparse_matrix_builder',
'sqrt', 'static', 'static_assert', 'static_print', 'stop_grad', 'svd',
'swizzle_generator', 'sym_eig', 'sync', 'tan', 'tanh', 'template', 'tools',
'types', 'u16', 'u32', 'u64', 'u8', 'ui', 'uint16', 'uint32', 'uint64',
'uint8', 'vulkan', 'wasm', 'x64', 'x86_64', 'zero'
'dx11', 'dx12', 'eig', 'exp', 'experimental', 'extension', 'f16', 'f32',
'f64', 'field', 'float16', 'float32', 'float64', 'floor', 'func',
'get_addr', 'get_compute_stream_device_time_elapsed_us',
'global_thread_idx', 'gpu', 'graph', 'grouped', 'hex_to_rgb', 'i', 'i16',
'i32', 'i64', 'i8', 'ij', 'ijk', 'ijkl', 'ijl', 'ik', 'ikl', 'il', 'init',
'int16', 'int32', 'int64', 'int8', 'is_active', 'is_logging_effective',
'j', 'jk', 'jkl', 'jl', 'k', 'kernel', 'kl', 'l', 'lang', 'length',
'linalg', 'log', 'loop_config', 'math', 'max', 'mesh_local',
'mesh_patch_idx', 'metal', 'min', 'ndarray', 'ndrange', 'no_activate',
'one', 'opengl', 'polar_decompose', 'pow', 'profiler', 'randn', 'random',
'raw_div', 'raw_mod', 'ref', 'rescale_index', 'reset', 'rgb_to_hex',
'root', 'round', 'rsqrt', 'select', 'set_logging_level', 'simt', 'sin',
'solve', 'sparse_matrix_builder', 'sqrt', 'static', 'static_assert',
'static_print', 'stop_grad', 'svd', 'swizzle_generator', 'sym_eig', 'sync',
'tan', 'tanh', 'template', 'tools', 'types', 'u16', 'u32', 'u64', 'u8',
'ui', 'uint16', 'uint32', 'uint64', 'uint8', 'vulkan', 'wasm', 'x64',
'x86_64', 'zero'
]
user_api[ti.ad] = [
'FwdMode', 'Tape', 'clear_all_gradients', 'grad_for', 'grad_replaced',
Expand Down
4 changes: 4 additions & 0 deletions tests/test_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
["cpp", "aot", "python_scripts", "kernel_aot_test1.py"],
"--arch=cuda"
],
"LlvmAotTest.DX12Kernel": [
["cpp", "aot", "python_scripts", "kernel_aot_test1.py"],
"--arch=dx12"
],
"LlvmAotTest.CpuField": [
["cpp", "aot", "python_scripts", "field_aot_test.py"],
"--arch=cpu"
Expand Down

0 comments on commit 42d4087

Please sign in to comment.