Skip to content

Commit

Permalink
[dx12] Drop code for llvm passes which prepare for DXIL generation. (#…
Browse files Browse the repository at this point in the history
…5998)

2 passes are added for DXIL generation.

TaichiIntrinsicLower translates taichi intrinsics such as thread_idx
into the form the DirectX backend expects.

TaichiRuntimeContextLower will translate the TaichiRuntimeContext
parameter of a kernel into Buffers/ConstantBuffers.
TaichiRuntimeContextLower is empty for now.
It is added after inlining so that optimizations have already reduced the
loads/stores on temporary pointers, which makes it easier to tell when a
store targets the TaichiRuntimeContext.

Related issue = #5276

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
python3kgae and pre-commit-ci[bot] authored Sep 15, 2022
1 parent a9f2905 commit e7bdbff
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 0 deletions.
2 changes: 2 additions & 0 deletions taichi/codegen/dx12/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ target_sources(dx12_codegen
PRIVATE
codegen_dx12.cpp
dx12_global_optimize_module.cpp
dx12_lower_intrinsic.cpp
dx12_lower_runtime_context.cpp
)

target_include_directories(dx12_codegen
Expand Down
19 changes: 19 additions & 0 deletions taichi/codegen/dx12/dx12_global_optimize_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "taichi/util/file_sequence_writer.h"
#include "taichi/runtime/llvm/llvm_context.h"

#include "dx12_llvm_passes.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Function.h"
Expand Down Expand Up @@ -38,6 +39,8 @@ namespace taichi {
namespace lang {
namespace directx12 {

const char *NumWorkGroupsCBName = "num_work_groups.cbuf";

const llvm::StringRef ShaderAttrKindStr = "hlsl.shader";

void mark_function_as_cs_entry(::llvm::Function *F) {
Expand All @@ -53,6 +56,16 @@ void set_num_threads(llvm::Function *F, unsigned x, unsigned y, unsigned z) {
F->addFnAttr(NumThreadsAttrKindStr, Str);
}

/// Creates a global variable named \p Name of type \p Ty in module \p M to
/// stand for a resource.
///
/// The variable is non-constant, has external linkage and no initializer, and
/// its unnamed_addr attribute is explicitly set to None.
///
/// \param M    Module that receives the new global.
/// \param Name Name given to the global.
/// \param Ty   Value type of the global.
/// \returns the newly created global variable.
GlobalVariable *createGlobalVariableForResource(Module &M,
                                                const char *Name,
                                                llvm::Type *Ty) {
  const bool IsConstant = false;
  llvm::Constant *const NoInitializer = nullptr;
  GlobalVariable *Resource =
      new GlobalVariable(M, Ty, IsConstant, GlobalValue::ExternalLinkage,
                         NoInitializer, Name);
  Resource->setUnnamedAddr(GlobalValue::UnnamedAddr::None);
  return Resource;
}

std::vector<uint8_t> global_optimize_module(llvm::Module *module,
CompileConfig &config) {
TI_AUTO_PROF
Expand Down Expand Up @@ -104,6 +117,9 @@ std::vector<uint8_t> global_optimize_module(llvm::Module *module,

module->setDataLayout(target_machine->createDataLayout());

// Lower taichi intrinsic first.
module_pass_manager.add(createTaichiIntrinsicLowerPass(&config));

module_pass_manager.add(createTargetTransformInfoWrapperPass(
target_machine->getTargetIRAnalysis()));
function_pass_manager.add(createTargetTransformInfoWrapperPass(
Expand All @@ -119,6 +135,9 @@ std::vector<uint8_t> global_optimize_module(llvm::Module *module,

b.populateFunctionPassManager(function_pass_manager);
b.populateModulePassManager(module_pass_manager);
// Add passes after inline.
module_pass_manager.add(createTaichiRuntimeContextLowerPass());

llvm::SmallString<256> str;
llvm::raw_svector_ostream OS(str);
// Write DXIL container to OS.
Expand Down
24 changes: 24 additions & 0 deletions taichi/codegen/dx12/dx12_llvm_passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
namespace llvm {
class Function;
class Module;
class Type;
class GlobalVariable;
} // namespace llvm

namespace taichi {
Expand All @@ -18,6 +20,9 @@ namespace directx12 {
void mark_function_as_cs_entry(llvm::Function *);
bool is_cs_entry(llvm::Function *);
void set_num_threads(llvm::Function *, unsigned x, unsigned y, unsigned z);
llvm::GlobalVariable *createGlobalVariableForResource(llvm::Module &M,
const char *Name,
llvm::Type *Ty);

std::vector<uint8_t> global_optimize_module(llvm::Module *module,
CompileConfig &config);
Expand All @@ -27,3 +32,22 @@ extern const char *NumWorkGroupsCBName;
} // namespace directx12
} // namespace lang
} // namespace taichi

// Declarations of the taichi DX12 lowering passes. They live in namespace
// llvm so they sit next to LLVM's own initialize*/create* pass entry points.
namespace llvm {
class ModulePass;
class PassRegistry;
class Function;

/// Initializer for the taichi RuntimeContext lowering pass.
void initializeTaichiRuntimeContextLowerPass(PassRegistry &);

/// Creates the pass meant to lower the taichi RuntimeContext kernel parameter
/// into DXIL resources (Buffers/ConstantBuffers).
ModulePass *createTaichiRuntimeContextLowerPass();

/// Initializer for taichi intrinsic lower.
void initializeTaichiIntrinsicLowerPass(PassRegistry &);

/// Pass to lower taichi intrinsic into DXIL intrinsic.
/// \p config supplies the GPU block dimension used for block_dim.
ModulePass *createTaichiIntrinsicLowerPass(taichi::lang::CompileConfig *config);

} // namespace llvm
121 changes: 121 additions & 0 deletions taichi/codegen/dx12/dx12_lower_intrinsic.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@

#include "dx12_llvm_passes.h"
#include "llvm/Pass.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Instructions.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

#include "taichi/program/compile_config.h"
#include "taichi/runtime/llvm/llvm_context.h"

using namespace llvm;
using namespace taichi::lang::directx12;

#define DEBUG_TYPE "dxil-taichi-intrinsic-lower"

namespace {

class TaichiIntrinsicLower : public ModulePass {
public:
bool runOnModule(Module &M) override {
auto &Ctx = M.getContext();
// patch intrinsic
auto patch_intrinsic = [&](std::string name, Intrinsic::ID intrin,
bool ret = true,
std::vector<llvm::Type *> types = {},
std::vector<llvm::Value *> extra_args = {}) {
auto func = M.getFunction(name);
if (!func) {
return;
}
func->deleteBody();
auto bb = llvm::BasicBlock::Create(Ctx, "entry", func);
IRBuilder<> builder(Ctx);
builder.SetInsertPoint(bb);
std::vector<llvm::Value *> args;
for (auto &arg : func->args())
args.push_back(&arg);
args.insert(args.end(), extra_args.begin(), extra_args.end());
if (ret) {
builder.CreateRet(builder.CreateIntrinsic(intrin, types, args));
} else {
builder.CreateIntrinsic(intrin, types, args);
builder.CreateRetVoid();
}
func->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
taichi::lang::TaichiLLVMContext::mark_inline(func);
};

llvm::IRBuilder<> B(Ctx);
Value *i32Zero = B.getInt32(0);

auto patch_intrinsic_to_const = [&](std::string name, Constant *C,
Type *Ty) {
auto func = M.getFunction(name);
if (!func) {
return;
}
func->deleteBody();
auto bb = llvm::BasicBlock::Create(Ctx, "entry", func);
IRBuilder<> B(Ctx);
B.SetInsertPoint(bb);
Value *V = C;
if (V->getType()->isPointerTy())
V = B.CreateLoad(Ty, C);
B.CreateRet(V);
func->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
taichi::lang::TaichiLLVMContext::mark_inline(func);
};
// group thread id.
patch_intrinsic("thread_idx", Intrinsic::dx_thread_id_in_group, true, {},
{i32Zero});
// group idx.
patch_intrinsic("block_idx", Intrinsic::dx_group_id, true, {}, {i32Zero});
// Group Size
unsigned group_size = 64;
if (config)
group_size = config->default_gpu_block_dim;

auto *I32Ty = B.getInt32Ty();
Constant *block_dim = B.getInt32(group_size);
patch_intrinsic_to_const("block_dim", block_dim, I32Ty);
// Num work groups will be in a special CBuffer.
// TaichiRuntimeContextLower pass will place the CBuffer to special binding
// space.
Type *TyNumWorkGroups = FixedVectorType::get(I32Ty, 3);
Constant *CBNumWorkGroups = createGlobalVariableForResource(
M, NumWorkGroupsCBName, TyNumWorkGroups);

Constant *NumWorkGroupX = cast<Constant>(
B.CreateConstGEP2_32(TyNumWorkGroups, CBNumWorkGroups, 0, 0));
patch_intrinsic_to_const("grid_dim", NumWorkGroupX, I32Ty);
return true;
}

TaichiIntrinsicLower(taichi::lang::CompileConfig *config = nullptr)
: ModulePass(ID), config(config) {
initializeTaichiIntrinsicLowerPass(*PassRegistry::getPassRegistry());
}

static char ID; // Pass identification.
private:
taichi::lang::CompileConfig *config;
};
char TaichiIntrinsicLower::ID = 0;

} // end anonymous namespace

// Defines initializeTaichiIntrinsicLowerPass(PassRegistry &), which registers
// the pass under the argument string DEBUG_TYPE with the description below.
// The two trailing flags are the macro's `cfg` and `analysis` parameters.
INITIALIZE_PASS(TaichiIntrinsicLower,
DEBUG_TYPE,
"Lower taichi intrinsic",
false,
false)

/// Factory for the taichi intrinsic lowering pass.
/// \param config Compile configuration forwarded to the pass constructor.
/// \returns a pass object allocated with `new`.
llvm::ModulePass *llvm::createTaichiIntrinsicLowerPass(
    taichi::lang::CompileConfig *config) {
  llvm::ModulePass *lowering_pass = new TaichiIntrinsicLower(config);
  return lowering_pass;
}
49 changes: 49 additions & 0 deletions taichi/codegen/dx12/dx12_lower_runtime_context.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@


#include "dx12_llvm_passes.h"

#include "llvm/Pass.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Instructions.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

#include "taichi/program/compile_config.h"
#include "taichi/runtime/llvm/llvm_context.h"

using namespace llvm;
using namespace taichi::lang::directx12;

#define DEBUG_TYPE "dxil-taichi-runtime-context-lower"

namespace {

/// Module pass intended to lower the taichi RuntimeContext into DXIL
/// resources. The lowering is not implemented yet, so the pass currently
/// leaves the module untouched.
class TaichiRuntimeContextLower : public ModulePass {
 public:
  /// Placeholder: performs no transformation on \p M.
  /// \returns false, because the module is not modified. A module pass must
  ///          only report true when it actually changed the module.
  bool runOnModule(Module &M) override {
    // TODO: lower taichi RuntimeContext into DXIL resources.
    return false;
  }

  TaichiRuntimeContextLower() : ModulePass(ID) {
    initializeTaichiRuntimeContextLowerPass(*PassRegistry::getPassRegistry());
  }

  static char ID;  // Pass identification.
};
char TaichiRuntimeContextLower::ID = 0;

} // end anonymous namespace

// Defines initializeTaichiRuntimeContextLowerPass(PassRegistry &), which
// registers the pass under the argument string DEBUG_TYPE with the
// description below. The two trailing flags are the macro's `cfg` and
// `analysis` parameters.
INITIALIZE_PASS(TaichiRuntimeContextLower,
DEBUG_TYPE,
"Lower taichi RuntimeContext",
false,
false)

/// Factory for the taichi RuntimeContext lowering pass.
/// \returns a pass object allocated with `new`.
llvm::ModulePass *llvm::createTaichiRuntimeContextLowerPass() {
  llvm::ModulePass *lowering_pass = new TaichiRuntimeContextLower();
  return lowering_pass;
}

0 comments on commit e7bdbff

Please sign in to comment.