Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[type] [refactor] Decouple quant from SNode 3/n: Extend bit pointers #5232

Merged
merged 2 commits into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions taichi/backends/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,10 +540,10 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
auto val_type = ptr_type->get_pointee_type();
if (auto qit = val_type->cast<QuantIntType>()) {
dtype = get_ch->input_snode->physical_type;
auto [data_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->src]);
auto [data_ptr, bit_offset] = load_bit_ptr(llvm_val[stmt->src]);
data_ptr = builder->CreateBitCast(data_ptr, llvm_ptr_type(dtype));
auto data = create_intrinsic_load(dtype, data_ptr);
llvm_val[stmt] = extract_quant_int(data, bit_offset, qit, dtype);
llvm_val[stmt] = extract_quant_int(data, bit_offset, qit);
} else {
// TODO: support __ldg
TI_ASSERT(val_type->is<QuantFixedType>() ||
Expand Down
94 changes: 36 additions & 58 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1195,11 +1195,9 @@ llvm::Value *CodeGenLLVM::quant_type_atomic(AtomicOpStmt *stmt) {

auto dst_type = stmt->dest->ret_type->as<PointerType>()->get_pointee_type();
if (auto qit = dst_type->cast<QuantIntType>()) {
return atomic_add_quant_int(
stmt, qit, stmt->dest->as<GetChStmt>()->input_snode->physical_type);
return atomic_add_quant_int(stmt, qit);
} else if (auto qfxt = dst_type->cast<QuantFixedType>()) {
return atomic_add_quant_fixed(
stmt, qfxt, stmt->dest->as<GetChStmt>()->input_snode->physical_type);
return atomic_add_quant_fixed(stmt, qfxt);
} else {
return nullptr;
}
Expand Down Expand Up @@ -1354,7 +1352,6 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
}
}
store_quant_int(llvm_val[stmt->dest], pointee_type->as<QuantIntType>(),
stmt->dest->as<GetChStmt>()->input_snode->physical_type,
llvm_val[stmt->val], true);
} else {
builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]);
Expand All @@ -1368,9 +1365,7 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) {
if (ptr_type->is_bit_pointer()) {
auto val_type = ptr_type->get_pointee_type();
if (auto qit = val_type->cast<QuantIntType>()) {
llvm_val[stmt] = load_quant_int(
llvm_val[stmt->src], qit,
stmt->src->as<GetChStmt>()->input_snode->physical_type);
llvm_val[stmt] = load_quant_int(llvm_val[stmt->src], qit);
} else {
TI_ASSERT(val_type->is<QuantFixedType>() ||
val_type->is<QuantFloatType>());
Expand Down Expand Up @@ -1479,61 +1474,44 @@ void CodeGenLLVM::visit(LinearizeStmt *stmt) {

// Visitor entry for IntegerOffsetStmt. No LLVM code is generated here: the
// body consists solely of the TI_NOT_IMPLEMENTED macro (presumably it reports
// an error at codegen time -- confirm against the macro's definition).
void CodeGenLLVM::visit(IntegerOffsetStmt *stmt){TI_NOT_IMPLEMENTED}

llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base,
llvm::Value *bit_offset) {
// 1. get the bit pointer LLVM struct
// struct bit_pointer {
// i8* byte_ptr;
// i32 offset;
llvm::Value *CodeGenLLVM::create_bit_ptr(llvm::Value *byte_ptr,
llvm::Value *bit_offset) {
// 1. define the bit pointer struct (X=8/16/32/64)
// struct bit_pointer_X {
// iX* byte_ptr;
// i32 bit_offset;
// };
TI_ASSERT(bit_offset->getType()->isIntegerTy(32));
auto struct_type = llvm::StructType::get(
strongoier marked this conversation as resolved.
Show resolved Hide resolved
*llvm_context, {llvm::Type::getInt8PtrTy(*llvm_context),
llvm::Type::getInt32Ty(*llvm_context)});
*llvm_context, {byte_ptr->getType(), bit_offset->getType()});
// 2. allocate the bit pointer struct
auto bit_ptr_struct = create_entry_block_alloca(struct_type);
// 3. store `byte_ptr_base` into `bit_ptr_struct` (if provided)
if (byte_ptr_base) {
auto byte_ptr = builder->CreateBitCast(
byte_ptr_base, llvm::PointerType::getInt8PtrTy(*llvm_context));
builder->CreateStore(
byte_ptr, builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0),
tlctx->get_constant(0)}));
}
// 4. store `offset` in `bit_ptr_struct` (if provided)
if (bit_offset) {
builder->CreateStore(
bit_offset,
builder->CreateGEP(bit_ptr_struct,
{tlctx->get_constant(0), tlctx->get_constant(1)}));
}
return bit_ptr_struct;
auto bit_ptr = create_entry_block_alloca(struct_type);
// 3. store `byte_ptr`
builder->CreateStore(
byte_ptr, builder->CreateGEP(
bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)}));
// 4. store `bit_offset`
builder->CreateStore(bit_offset,
builder->CreateGEP(bit_ptr, {tlctx->get_constant(0),
tlctx->get_constant(1)}));
return bit_ptr;
}

// Reads both fields back out of a bit pointer struct.
//
// `bit_ptr` is addressed with GEP indices {0, 0} and {0, 1}; the value loaded
// from field 0 is returned first (the byte pointer) and the value loaded from
// field 1 is returned second (the bit offset).
std::tuple<llvm::Value *, llvm::Value *> CodeGenLLVM::load_bit_ptr(
    llvm::Value *bit_ptr) {
  // Field 0: address of the stored byte pointer, then its value.
  auto *byte_ptr_addr = builder->CreateGEP(
      bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)});
  llvm::Value *byte_ptr = builder->CreateLoad(byte_ptr_addr);

  // Field 1: address of the stored bit offset, then its value.
  auto *bit_offset_addr = builder->CreateGEP(
      bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(1)});
  llvm::Value *bit_offset = builder->CreateLoad(bit_offset_addr);

  return std::make_tuple(byte_ptr, bit_offset);
}

llvm::Value *CodeGenLLVM::offset_bit_ptr(llvm::Value *input_bit_ptr,
llvm::Value *CodeGenLLVM::offset_bit_ptr(llvm::Value *bit_ptr,
int bit_offset_delta) {
auto byte_ptr_base = builder->CreateLoad(builder->CreateGEP(
input_bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)}));
auto input_offset = builder->CreateLoad(builder->CreateGEP(
input_bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(1)}));
auto [byte_ptr, bit_offset] = load_bit_ptr(bit_ptr);
auto new_bit_offset =
builder->CreateAdd(input_offset, tlctx->get_constant(bit_offset_delta));
return create_bit_ptr_struct(byte_ptr_base, new_bit_offset);
}

std::tuple<llvm::Value *, llvm::Value *> CodeGenLLVM::load_bit_pointer(
llvm::Value *ptr) {
// 1. load byte pointer
auto byte_ptr_in_bit_struct =
builder->CreateGEP(ptr, {tlctx->get_constant(0), tlctx->get_constant(0)});
auto byte_ptr = builder->CreateLoad(byte_ptr_in_bit_struct);
TI_ASSERT(byte_ptr->getType()->getPointerElementType()->isIntegerTy(8));

// 2. load bit offset
auto bit_offset_in_bit_struct =
builder->CreateGEP(ptr, {tlctx->get_constant(0), tlctx->get_constant(1)});
auto bit_offset = builder->CreateLoad(bit_offset_in_bit_struct);
TI_ASSERT(bit_offset->getType()->isIntegerTy(32));
return std::make_tuple(byte_ptr, bit_offset);
builder->CreateAdd(bit_offset, tlctx->get_constant(bit_offset_delta));
return create_bit_ptr(byte_ptr, new_bit_offset);
}

void CodeGenLLVM::visit(SNodeLookupStmt *stmt) {
Expand All @@ -1560,7 +1538,7 @@ void CodeGenLLVM::visit(SNodeLookupStmt *stmt) {
snode->dt->as<BitArrayType>()->get_element_num_bits();
auto offset = tlctx->get_constant(element_num_bits);
offset = builder->CreateMul(offset, llvm_val[stmt->input_index]);
llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_snode], offset);
llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_snode], offset);
} else {
TI_INFO(snode_type_name(snode->type));
TI_NOT_IMPLEMENTED
Expand All @@ -1575,7 +1553,7 @@ void CodeGenLLVM::visit(GetChStmt *stmt) {
auto bit_offset = bit_struct->get_member_bit_offset(
stmt->input_snode->child_id(stmt->output_snode));
auto offset = tlctx->get_constant(bit_offset);
llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_ptr], offset);
llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_ptr], offset);
} else {
auto ch = create_call(stmt->output_snode->get_ch_from_parent_func_name(),
{builder->CreateBitCast(
Expand Down
32 changes: 7 additions & 25 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,9 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(SNodeOpStmt *stmt) override;

llvm::Value *atomic_add_quant_fixed(AtomicOpStmt *stmt,
QuantFixedType *qfxt,
Type *physical_type);
llvm::Value *atomic_add_quant_fixed(AtomicOpStmt *stmt, QuantFixedType *qfxt);

llvm::Value *atomic_add_quant_int(AtomicOpStmt *stmt,
QuantIntType *qit,
Type *physical_type);
llvm::Value *atomic_add_quant_int(AtomicOpStmt *stmt, QuantIntType *qit);

llvm::Value *quant_fixed_to_quant_int(QuantFixedType *qfxt,
QuantIntType *qit,
Expand All @@ -252,20 +248,11 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void store_quant_int(llvm::Value *bit_ptr,
QuantIntType *qit,
Type *physical_type,
llvm::Value *value,
bool atomic);

void store_quant_int(llvm::Value *byte_ptr,
llvm::Value *bit_offset,
QuantIntType *qit,
Type *physical_type,
llvm::Value *value,
bool atomic);

void store_masked(llvm::Value *byte_ptr,
uint64 mask,
Type *physical_type,
llvm::Value *value,
bool atomic);

Expand All @@ -282,22 +269,18 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *extract_quant_float(llvm::Value *local_bit_struct,
SNode *digits_snode);

llvm::Value *load_quant_int(llvm::Value *ptr,
QuantIntType *qit,
Type *physical_type);
llvm::Value *load_quant_int(llvm::Value *ptr, QuantIntType *qit);

llvm::Value *extract_quant_int(llvm::Value *physical_value,
llvm::Value *bit_offset,
QuantIntType *qit,
Type *physical_type);
QuantIntType *qit);

llvm::Value *reconstruct_quant_fixed(llvm::Value *digits,
QuantFixedType *qfxt);

llvm::Value *load_quant_float(llvm::Value *digits_bit_ptr,
llvm::Value *exponent_bit_ptr,
QuantFloatType *qflt,
Type *physical_type,
bool shared_exponent);

llvm::Value *reconstruct_quant_float(llvm::Value *input_digits,
Expand All @@ -319,12 +302,11 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(IntegerOffsetStmt *stmt) override;

llvm::Value *create_bit_ptr_struct(llvm::Value *byte_ptr_base = nullptr,
llvm::Value *bit_offset = nullptr);
llvm::Value *create_bit_ptr(llvm::Value *byte_ptr, llvm::Value *bit_offset);

llvm::Value *offset_bit_ptr(llvm::Value *input_bit_ptr, int bit_offset_delta);
std::tuple<llvm::Value *, llvm::Value *> load_bit_ptr(llvm::Value *bit_ptr);

std::tuple<llvm::Value *, llvm::Value *> load_bit_pointer(llvm::Value *ptr);
llvm::Value *offset_bit_ptr(llvm::Value *bit_ptr, int bit_offset_delta);

void visit(SNodeLookupStmt *stmt) override;

Expand Down
Loading