diff --git a/taichi/codegen/dx12/CMakeLists.txt b/taichi/codegen/dx12/CMakeLists.txt
index 24dca8d27238b..a22c3e52d91e3 100644
--- a/taichi/codegen/dx12/CMakeLists.txt
+++ b/taichi/codegen/dx12/CMakeLists.txt
@@ -5,6 +5,8 @@ target_sources(dx12_codegen
   PRIVATE
     codegen_dx12.cpp
     dx12_global_optimize_module.cpp
+    dx12_lower_intrinsic.cpp
+    dx12_lower_runtime_context.cpp
   )
 
 target_include_directories(dx12_codegen
diff --git a/taichi/codegen/dx12/dx12_global_optimize_module.cpp b/taichi/codegen/dx12/dx12_global_optimize_module.cpp
index 94a9d7c003328..316a295d5aea6 100644
--- a/taichi/codegen/dx12/dx12_global_optimize_module.cpp
+++ b/taichi/codegen/dx12/dx12_global_optimize_module.cpp
@@ -8,6 +8,7 @@
 #include "taichi/util/file_sequence_writer.h"
 #include "taichi/runtime/llvm/llvm_context.h"
 
+#include "dx12_llvm_passes.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/IR/Function.h"
@@ -38,6 +39,8 @@ namespace taichi {
 namespace lang {
 namespace directx12 {
 
+const char *NumWorkGroupsCBName = "num_work_groups.cbuf";
+
 const llvm::StringRef ShaderAttrKindStr = "hlsl.shader";
 
 void mark_function_as_cs_entry(::llvm::Function *F) {
@@ -53,6 +56,16 @@ void set_num_threads(llvm::Function *F, unsigned x, unsigned y, unsigned z) {
   F->addFnAttr(NumThreadsAttrKindStr, Str);
 }
 
+GlobalVariable *createGlobalVariableForResource(Module &M,
+                                                const char *Name,
+                                                llvm::Type *Ty) {
+  auto *GV = new GlobalVariable(M, Ty, /*isConstant*/ false,
+                                GlobalValue::LinkageTypes::ExternalLinkage,
+                                /*Initializer*/ nullptr, Name);
+  GV->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::None);
+  return GV;
+}
+
 std::vector<uint8_t> global_optimize_module(llvm::Module *module,
                                             CompileConfig &config) {
   TI_AUTO_PROF
@@ -104,6 +117,9 @@ std::vector<uint8_t> global_optimize_module(llvm::Module *module,
 
   module->setDataLayout(target_machine->createDataLayout());
 
+  // Lower taichi intrinsic first.
+  module_pass_manager.add(createTaichiIntrinsicLowerPass(&config));
+
   module_pass_manager.add(createTargetTransformInfoWrapperPass(
       target_machine->getTargetIRAnalysis()));
   function_pass_manager.add(createTargetTransformInfoWrapperPass(
@@ -119,6 +135,9 @@ std::vector<uint8_t> global_optimize_module(llvm::Module *module,
   b.populateFunctionPassManager(function_pass_manager);
   b.populateModulePassManager(module_pass_manager);
 
+  // Add passes after inline.
+  module_pass_manager.add(createTaichiRuntimeContextLowerPass());
+
   llvm::SmallString<256> str;
   llvm::raw_svector_ostream OS(str);
   // Write DXIL container to OS.
diff --git a/taichi/codegen/dx12/dx12_llvm_passes.h b/taichi/codegen/dx12/dx12_llvm_passes.h
index c07896abba1a3..2821cead303d3 100644
--- a/taichi/codegen/dx12/dx12_llvm_passes.h
+++ b/taichi/codegen/dx12/dx12_llvm_passes.h
@@ -7,6 +7,8 @@
 namespace llvm {
 class Function;
 class Module;
+class Type;
+class GlobalVariable;
 }  // namespace llvm
 
 namespace taichi {
@@ -18,6 +20,11 @@ namespace directx12 {
 void mark_function_as_cs_entry(llvm::Function *);
 bool is_cs_entry(llvm::Function *);
 void set_num_threads(llvm::Function *, unsigned x, unsigned y, unsigned z);
+/// Adds to \p M an external, non-constant global variable named \p Name of
+/// type \p Ty with no initializer, and returns it.
+llvm::GlobalVariable *createGlobalVariableForResource(llvm::Module &M,
+                                                      const char *Name,
+                                                      llvm::Type *Ty);
 
 std::vector<uint8_t> global_optimize_module(llvm::Module *module,
                                             CompileConfig &config);
@@ -27,3 +34,21 @@ extern const char *NumWorkGroupsCBName;
 }  // namespace directx12
 }  // namespace lang
 }  // namespace taichi
+
+namespace llvm {
+class ModulePass;
+class PassRegistry;
+
+/// Initializer for the taichi RuntimeContext lowering pass.
+void initializeTaichiRuntimeContextLowerPass(PassRegistry &);
+
+/// Pass to lower taichi RuntimeContext into DXIL resources (still a TODO).
+ModulePass *createTaichiRuntimeContextLowerPass();
+
+/// Initializer for taichi intrinsic lower.
+void initializeTaichiIntrinsicLowerPass(PassRegistry &);
+
+/// Pass to lower taichi intrinsic into DXIL intrinsic.
+ModulePass *createTaichiIntrinsicLowerPass(taichi::lang::CompileConfig *config);
+
+}  // namespace llvm
diff --git a/taichi/codegen/dx12/dx12_lower_intrinsic.cpp b/taichi/codegen/dx12/dx12_lower_intrinsic.cpp
new file mode 100644
index 0000000000000..2a694ca04af49
--- /dev/null
+++ b/taichi/codegen/dx12/dx12_lower_intrinsic.cpp
@@ -0,0 +1,139 @@
+
+#include "dx12_llvm_passes.h"
+#include "llvm/Pass.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+
+#include "taichi/program/compile_config.h"
+#include "taichi/runtime/llvm/llvm_context.h"
+
+using namespace llvm;
+using namespace taichi::lang::directx12;
+
+#define DEBUG_TYPE "dxil-taichi-intrinsic-lower"
+
+namespace {
+
+/// Module pass that rewrites the bodies of the taichi runtime helpers
+/// `thread_idx`, `block_idx`, `block_dim` and `grid_dim` (each only if the
+/// module defines it) using DirectX intrinsics, a constant, and a load from
+/// the num-work-groups constant buffer.
+class TaichiIntrinsicLower : public ModulePass {
+ public:
+  bool runOnModule(Module &M) override {
+    auto &Ctx = M.getContext();
+
+    // Replaces the body of function `name` with one call to `intrin`, passing
+    // the function's own arguments followed by `extra_args`, then makes the
+    // function internal and marks it for inlining. No-op if `name` is absent.
+    auto patch_intrinsic = [&](StringRef name, Intrinsic::ID intrin,
+                               bool ret = true,
+                               const std::vector<Type *> &types = {},
+                               const std::vector<Value *> &extra_args = {}) {
+      auto *func = M.getFunction(name);
+      if (!func) {
+        return;
+      }
+      func->deleteBody();
+      auto *bb = llvm::BasicBlock::Create(Ctx, "entry", func);
+      IRBuilder<> builder(Ctx);
+      builder.SetInsertPoint(bb);
+      std::vector<Value *> args;
+      for (auto &arg : func->args())
+        args.push_back(&arg);
+      args.insert(args.end(), extra_args.begin(), extra_args.end());
+      if (ret) {
+        builder.CreateRet(builder.CreateIntrinsic(intrin, types, args));
+      } else {
+        builder.CreateIntrinsic(intrin, types, args);
+        builder.CreateRetVoid();
+      }
+      func->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
+      taichi::lang::TaichiLLVMContext::mark_inline(func);
+    };
+
+    llvm::IRBuilder<> B(Ctx);
+    Value *i32Zero = B.getInt32(0);
+
+    // Replaces the body of function `name` with `ret C`, or with a load of
+    // type `Ty` from `C` when `C` is a pointer, then makes the function
+    // internal and marks it for inlining. No-op if `name` is absent.
+    auto patch_intrinsic_to_const = [&](StringRef name, Constant *C,
+                                        Type *Ty) {
+      auto *func = M.getFunction(name);
+      if (!func) {
+        return;
+      }
+      func->deleteBody();
+      auto *bb = llvm::BasicBlock::Create(Ctx, "entry", func);
+      // Distinct name so the outer builder `B` is not shadowed.
+      IRBuilder<> body_builder(Ctx);
+      body_builder.SetInsertPoint(bb);
+      Value *V = C;
+      if (V->getType()->isPointerTy())
+        V = body_builder.CreateLoad(Ty, C);
+      body_builder.CreateRet(V);
+      func->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
+      taichi::lang::TaichiLLVMContext::mark_inline(func);
+    };
+
+    // group thread id.
+    patch_intrinsic("thread_idx", Intrinsic::dx_thread_id_in_group, true, {},
+                    {i32Zero});
+    // group idx.
+    patch_intrinsic("block_idx", Intrinsic::dx_group_id, true, {}, {i32Zero});
+    // Group size: config->default_gpu_block_dim when a config was supplied,
+    // otherwise 64.
+    unsigned group_size = 64;
+    if (config)
+      group_size = config->default_gpu_block_dim;
+
+    auto *I32Ty = B.getInt32Ty();
+    Constant *block_dim = B.getInt32(group_size);
+    patch_intrinsic_to_const("block_dim", block_dim, I32Ty);
+    // Num work groups will be in a special CBuffer.
+    // TaichiRuntimeContextLower pass will place the CBuffer into a special
+    // binding space.
+    Type *TyNumWorkGroups = FixedVectorType::get(I32Ty, 3);
+    Constant *CBNumWorkGroups = createGlobalVariableForResource(
+        M, NumWorkGroupsCBName, TyNumWorkGroups);
+
+    // `B` has no insertion point, so this relies on the builder folding the
+    // all-constant GEP to a constant expression (cast<> asserts this).
+    Constant *NumWorkGroupX = cast<Constant>(
+        B.CreateConstGEP2_32(TyNumWorkGroups, CBNumWorkGroups, 0, 0));
+    patch_intrinsic_to_const("grid_dim", NumWorkGroupX, I32Ty);
+    // The CBuffer global is created unconditionally, so the module is always
+    // modified.
+    return true;
+  }
+
+  TaichiIntrinsicLower(taichi::lang::CompileConfig *config = nullptr)
+      : ModulePass(ID), config(config) {
+    initializeTaichiIntrinsicLowerPass(*PassRegistry::getPassRegistry());
+  }
+
+  static char ID;  // Pass identification.
+
+ private:
+  taichi::lang::CompileConfig *config;
+};
+char TaichiIntrinsicLower::ID = 0;
+
+}  // end anonymous namespace
+
+INITIALIZE_PASS(TaichiIntrinsicLower,
+                DEBUG_TYPE,
+                "Lower taichi intrinsic",
+                false,
+                false)
+
+llvm::ModulePass *llvm::createTaichiIntrinsicLowerPass(
+    taichi::lang::CompileConfig *config) {
+  return new TaichiIntrinsicLower(config);
+}
diff --git a/taichi/codegen/dx12/dx12_lower_runtime_context.cpp b/taichi/codegen/dx12/dx12_lower_runtime_context.cpp
new file mode 100644
index 0000000000000..26884557e2659
--- /dev/null
+++ b/taichi/codegen/dx12/dx12_lower_runtime_context.cpp
@@ -0,0 +1,51 @@
+
+
+#include "dx12_llvm_passes.h"
+
+#include "llvm/Pass.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+
+#include "taichi/program/compile_config.h"
+#include "taichi/runtime/llvm/llvm_context.h"
+
+using namespace llvm;
+using namespace taichi::lang::directx12;
+
+#define DEBUG_TYPE "dxil-taichi-runtime-context-lower"
+
+namespace {
+
+/// Placeholder module pass for lowering the taichi RuntimeContext; it does not
+/// transform the module yet.
+class TaichiRuntimeContextLower : public ModulePass {
+ public:
+  bool runOnModule(Module &M) override {
+    // TODO: lower taichi RuntimeContext into DXIL resources.
+    // Nothing is modified yet, so report the module as unchanged.
+    return false;
+  }
+
+  TaichiRuntimeContextLower() : ModulePass(ID) {
+    initializeTaichiRuntimeContextLowerPass(*PassRegistry::getPassRegistry());
+  }
+
+  static char ID;  // Pass identification.
+};
+char TaichiRuntimeContextLower::ID = 0;
+
+}  // end anonymous namespace
+
+INITIALIZE_PASS(TaichiRuntimeContextLower,
+                DEBUG_TYPE,
+                "Lower taichi RuntimeContext",
+                false,
+                false)
+
+llvm::ModulePass *llvm::createTaichiRuntimeContextLowerPass() {
+  return new TaichiRuntimeContextLower();
+}