Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[type] [refactor] Decouple quant from SNode 5/n: Rewrite load_quant_float() without SNode #5422

Merged
merged 2 commits into from
Jul 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1435,9 +1435,10 @@ void CodeGenLLVM::create_global_load(GlobalLoadStmt *stmt,
load_quant_fixed(ptr, qfxt, physical_type, should_cache_as_read_only);
} else {
TI_ASSERT(val_type->is<QuantFloatType>());
TI_ASSERT(get_ch->input_snode->dt->is<BitStructType>());
llvm_val[stmt] = load_quant_float(
ptr, get_ch->output_snode, val_type->as<QuantFloatType>(),
physical_type, should_cache_as_read_only);
ptr, get_ch->input_snode->dt->as<BitStructType>(),
get_ch->output_snode->id_in_bit_struct, should_cache_as_read_only);
}
} else {
// Byte pointer case.
Expand Down
11 changes: 5 additions & 6 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,13 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *reconstruct_quant_fixed(llvm::Value *digits,
QuantFixedType *qfxt);

llvm::Value *load_quant_float(llvm::Value *digits_bit_ptr,
SNode *digits_snode,
QuantFloatType *qflt,
Type *physical_type,
llvm::Value *load_quant_float(llvm::Value *digits_ptr,
BitStructType *bit_struct,
int digits_id,
bool should_cache_as_read_only);

llvm::Value *load_quant_float(llvm::Value *digits_bit_ptr,
llvm::Value *exponent_bit_ptr,
llvm::Value *load_quant_float(llvm::Value *digits_ptr,
llvm::Value *exponent_ptr,
QuantFloatType *qflt,
Type *physical_type,
bool should_cache_as_read_only,
Expand Down
37 changes: 19 additions & 18 deletions taichi/codegen/llvm/codegen_llvm_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,32 +532,33 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits,
return builder->CreateFMul(cast, s);
}

llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_bit_ptr,
SNode *digits_snode,
QuantFloatType *qflt,
Type *physical_type,
llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_ptr,
BitStructType *bit_struct,
int digits_id,
bool should_cache_as_read_only) {
auto exponent_snode = digits_snode->exp_snode;
// Compute the bit pointer of the exponent bits.
TI_ASSERT(digits_snode->parent == exponent_snode->parent);
auto exponent_bit_ptr = offset_bit_ptr(
digits_bit_ptr, exponent_snode->bit_offset - digits_snode->bit_offset);
return load_quant_float(digits_bit_ptr, exponent_bit_ptr, qflt, physical_type,
should_cache_as_read_only,
digits_snode->owns_shared_exponent);
auto exponent_id = bit_struct->get_member_exponent(digits_id);
auto exponent_bit_offset = bit_struct->get_member_bit_offset(exponent_id);
auto digits_bit_offset = bit_struct->get_member_bit_offset(digits_id);
auto bit_offset_delta = exponent_bit_offset - digits_bit_offset;
auto exponent_ptr = offset_bit_ptr(digits_ptr, bit_offset_delta);
auto qflt = bit_struct->get_member_type(digits_id)->as<QuantFloatType>();
auto physical_type = bit_struct->get_physical_type();
auto shared_exponent = bit_struct->get_member_owns_shared_exponent(digits_id);
return load_quant_float(digits_ptr, exponent_ptr, qflt, physical_type,
should_cache_as_read_only, shared_exponent);
}

llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_bit_ptr,
llvm::Value *exponent_bit_ptr,
llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_ptr,
llvm::Value *exponent_ptr,
QuantFloatType *qflt,
Type *physical_type,
bool should_cache_as_read_only,
bool shared_exponent) {
auto digits = load_quant_int(digits_bit_ptr,
qflt->get_digits_type()->as<QuantIntType>(),
physical_type, should_cache_as_read_only);
auto digits =
load_quant_int(digits_ptr, qflt->get_digits_type()->as<QuantIntType>(),
physical_type, should_cache_as_read_only);
auto exponent_val = load_quant_int(
exponent_bit_ptr, qflt->get_exponent_type()->as<QuantIntType>(),
exponent_ptr, qflt->get_exponent_type()->as<QuantIntType>(),
physical_type, should_cache_as_read_only);
return reconstruct_quant_float(digits, exponent_val, qflt, shared_exponent);
}
Expand Down