Skip to content

Commit

Permalink
[type] [refactor] Decouple quant from SNode 7/n: Rewrite BitStructSto…
Browse files Browse the repository at this point in the history
…reStmt codegen without SNode (#5475)

* [type] [refactor] Rewrite BitStructStoreStmt codegen without SNode

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Jul 20, 2022
1 parent 25c0636 commit 2f3ce21
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 108 deletions.
6 changes: 3 additions & 3 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *value,
bool atomic);

void store_masked(llvm::Value *byte_ptr,
llvm::Type *byte_ptr_ty,
void store_masked(llvm::Value *ptr,
llvm::Type *ty,
uint64 mask,
llvm::Value *value,
bool atomic);
Expand All @@ -244,7 +244,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Value *quant_int_or_quant_fixed_to_bits(llvm::Value *val,
Type *input_type,
Type *output_type);
llvm::Type *output_type);

void visit(BitStructStoreStmt *stmt) override;

Expand Down
171 changes: 70 additions & 101 deletions taichi/codegen/llvm/codegen_llvm_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,26 +110,25 @@ void CodeGenLLVM::store_quant_fixed(llvm::Value *bit_ptr,
to_quant_fixed(value, qfxt), atomic);
}

void CodeGenLLVM::store_masked(llvm::Value *byte_ptr,
llvm::Type *byte_ptr_ty,
void CodeGenLLVM::store_masked(llvm::Value *ptr,
llvm::Type *ty,
uint64 mask,
llvm::Value *value,
bool atomic) {
if (!mask) {
// do not store anything
return;
}
auto physical_type = byte_ptr_ty;
uint64 full_mask = (~(uint64)0) >> (64 - physical_type->getIntegerBitWidth());
uint64 full_mask = (~(uint64)0) >> (64 - ty->getIntegerBitWidth());
if ((!atomic || prog->config.quant_opt_atomic_demotion) &&
((mask & full_mask) == full_mask)) {
builder->CreateStore(value, byte_ptr);
builder->CreateStore(value, ptr);
return;
}
create_call(fmt::format("{}set_mask_b{}", atomic ? "atomic_" : "",
physical_type->getIntegerBitWidth()),
{byte_ptr, tlctx->get_constant(mask),
builder->CreateIntCast(value, physical_type, false)});
ty->getIntegerBitWidth()),
{ptr, tlctx->get_constant(mask),
builder->CreateIntCast(value, ty, false)});
}

llvm::Value *CodeGenLLVM::get_exponent_offset(llvm::Value *exponent,
Expand All @@ -144,9 +143,10 @@ llvm::Value *CodeGenLLVM::get_exponent_offset(llvm::Value *exponent,
tlctx->get_constant(0));
}

llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(llvm::Value *val,
Type *input_type,
Type *output_type) {
llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(
llvm::Value *val,
Type *input_type,
llvm::Type *output_type) {
QuantIntType *qit = nullptr;
if (auto qfxt = input_type->cast<QuantFixedType>()) {
qit = qfxt->get_digits_type()->as<QuantIntType>();
Expand All @@ -159,53 +159,47 @@ llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(llvm::Value *val,
val, tlctx->get_constant(qit->get_compute_type(),
uint64((1ULL << qit->get_num_bits()) - 1)));
}
val = builder->CreateZExt(val, llvm_type(output_type));
val = builder->CreateZExt(val, output_type);
return val;
}

void CodeGenLLVM::visit(BitStructStoreStmt *stmt) {
auto bit_struct_snode = stmt->get_bit_struct_snode();
auto bit_struct_physical_type =
bit_struct_snode->dt->as<BitStructType>()->get_physical_type();

int bit_struct_num_non_exponent_children = 0;
for (auto &ch : bit_struct_snode->ch) {
if (ch->exponent_users.empty()) {
bit_struct_num_non_exponent_children++;
auto bit_struct = stmt->get_bit_struct();
auto physical_type = llvm_type(bit_struct->get_physical_type());

int num_non_exponent_children = 0;
for (int i = 0; i < bit_struct->get_num_members(); i++) {
if (bit_struct->get_member_exponent_users(i).empty()) {
num_non_exponent_children++;
}
}
bool store_all_components = false;
if (prog->config.quant_opt_atomic_demotion &&
stmt->ch_ids.size() == bit_struct_num_non_exponent_children) {
stmt->ch_ids.size() == num_non_exponent_children) {
stmt->is_atomic = false;
store_all_components = true;
}

bool has_shared_exponent = false;
for (auto ch_id : stmt->ch_ids) {
if (bit_struct_snode->ch[ch_id]->owns_shared_exponent) {
if (bit_struct->get_member_owns_shared_exponent(ch_id)) {
has_shared_exponent = true;
}
}
// TODO: what about storing only shared-exponent floating-point SNodes
// that don't own the shared exponent?

if (has_shared_exponent) {
store_quant_floats_with_shared_exponents(stmt);
}

llvm::Value *bit_struct_val = nullptr;
for (int i = 0; i < stmt->ch_ids.size(); i++) {
auto ch_id = stmt->ch_ids[i];
auto val = llvm_val[stmt->values[i]];
auto &ch = bit_struct_snode->ch[ch_id];
if (has_shared_exponent && ch->exp_snode != nullptr &&
ch->exp_snode->exponent_users.size() > 1) {
auto exp = bit_struct->get_member_exponent(ch_id);
if (exp != -1 && bit_struct->get_member_exponent_users(exp).size() > 1) {
// already handled in store_quant_floats_with_shared_exponents
continue;
}
auto dtype = ch->dt;

auto dtype = bit_struct->get_member_type(ch_id);
auto val = llvm_val[stmt->values[i]];
if (auto qflt = dtype->cast<QuantFloatType>()) {
// Quant float type with non-shared exponent.
llvm::Value *digit_bits = nullptr;
Expand Down Expand Up @@ -243,20 +237,14 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) {
builder->CreateLShr(sign_bit, 31 - qflt->get_digit_bits()));
}

auto digits_snode = ch.get();
auto exponent_snode = digits_snode->exp_snode;

auto exponent_offset = get_exponent_offset(exponent_bits, qflt);
exponent_bits = builder->CreateSub(exponent_bits, exponent_offset);
exponent_bits =
create_call("max_i32", {exponent_bits, tlctx->get_constant(0)});

// Compute the bit pointer of the exponent bits.
TI_ASSERT(digits_snode->parent == exponent_snode->parent);

val = builder->CreateBitCast(exponent_bits,
llvm_type(bit_struct_physical_type));
val = builder->CreateShl(val, exponent_snode->bit_offset);
val = builder->CreateBitCast(exponent_bits, physical_type);
val = builder->CreateShl(val, bit_struct->get_member_bit_offset(exp));

if (bit_struct_val == nullptr) {
bit_struct_val = val;
Expand All @@ -272,12 +260,11 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) {
tlctx->get_constant(0));
val = builder->CreateSelect(exp_non_zero, digit_bits,
tlctx->get_constant(0));
val = builder->CreateBitCast(val, llvm_type(bit_struct_physical_type));
val = builder->CreateShl(val, digits_snode->bit_offset);
val = builder->CreateBitCast(val, physical_type);
val = builder->CreateShl(val, bit_struct->get_member_bit_offset(ch_id));
} else {
val = quant_int_or_quant_fixed_to_bits(val, dtype,
bit_struct_physical_type);
val = builder->CreateShl(val, bit_struct_snode->ch[ch_id]->bit_offset);
val = quant_int_or_quant_fixed_to_bits(val, dtype, physical_type);
val = builder->CreateShl(val, bit_struct->get_member_bit_offset(ch_id));
}

if (bit_struct_val == nullptr) {
Expand All @@ -292,68 +279,56 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) {
} else {
// Create a mask and use a single (atomic)CAS
uint64 mask = 0;
for (auto &ch_id : stmt->ch_ids) {
auto &ch = bit_struct_snode->ch[ch_id];
if (has_shared_exponent && ch->exp_snode != nullptr &&
ch->exp_snode->exponent_users.size() > 1) {
for (int i = 0; i < stmt->ch_ids.size(); i++) {
auto ch_id = stmt->ch_ids[i];
auto exp = bit_struct->get_member_exponent(ch_id);
if (exp != -1 && bit_struct->get_member_exponent_users(exp).size() > 1) {
// already handled in store_quant_floats_with_shared_exponents
continue;
}
auto dtype = ch->dt;
auto dtype = bit_struct->get_member_type(ch_id);
QuantIntType *qit = nullptr;
if (auto qflt = dtype->cast<QuantFloatType>()) {
auto exp = qflt->get_exponent_type();
auto exponent_qit = exp->as<QuantIntType>();
auto exponent_snode = ch->exp_snode;
auto exponent_qit = qflt->get_exponent_type()->as<QuantIntType>();
update_mask(mask, exponent_qit->get_num_bits(),
exponent_snode->bit_offset);
bit_struct->get_member_bit_offset(exp));
qit = qflt->get_digits_type()->as<QuantIntType>();
} else if (auto qfxt = dtype->cast<QuantFixedType>()) {
qit = qfxt->get_digits_type()->as<QuantIntType>();
} else {
qit = dtype->as<QuantIntType>();
}
update_mask(mask, qit->get_num_bits(), ch->bit_offset);
update_mask(mask, qit->get_num_bits(),
bit_struct->get_member_bit_offset(ch_id));
}
store_masked(llvm_val[stmt->ptr], llvm_type(bit_struct_physical_type), mask,
bit_struct_val, stmt->is_atomic);
store_masked(llvm_val[stmt->ptr], physical_type, mask, bit_struct_val,
stmt->is_atomic);
}
}

void CodeGenLLVM::store_quant_floats_with_shared_exponents(
BitStructStoreStmt *stmt) {
// handle each exponent separately
auto snode = stmt->get_bit_struct_snode();
auto bit_struct = snode->dt->as<BitStructType>();
auto bit_struct_physical_type = bit_struct->get_physical_type();
auto local_bit_struct = builder->CreateLoad(
#ifdef TI_LLVM_15
llvm_type(bit_struct_physical_type),
#endif
llvm_val[stmt->ptr]);
auto bit_struct = stmt->get_bit_struct();
auto physical_type = llvm_type(bit_struct->get_physical_type());
auto physical_value = builder->CreateLoad(physical_type, llvm_val[stmt->ptr]);
// fuse all stores into a masked store
llvm::Value *masked_val = nullptr;
uint64 mask = 0;
for (int i = 0; i < (int)snode->ch.size(); i++) {
if (snode->ch[i]->exponent_users.empty())
for (int i = 0; i < bit_struct->get_num_members(); i++) {
auto &exponent_users = bit_struct->get_member_exponent_users(i);
// make sure i-th member is a shared exponent
if (exponent_users.size() < 2)
continue;
// ch[i] must be an exponent SNode
auto &exp = snode->ch[i];
if (exp->exponent_users.size() == 1) {
// non-shared
continue;
}
// load all floats
// load all floats with the shared exponent
std::vector<llvm::Value *> floats;
for (auto &user : exp->exponent_users) {
auto ch_id = snode->child_id(user);
for (auto user : exponent_users) {
if (auto input =
std::find(stmt->ch_ids.begin(), stmt->ch_ids.end(), ch_id);
std::find(stmt->ch_ids.begin(), stmt->ch_ids.end(), user);
input != stmt->ch_ids.end()) {
floats.push_back(llvm_val[stmt->values[input - stmt->ch_ids.begin()]]);
} else {
floats.push_back(
extract_quant_float(local_bit_struct, bit_struct, ch_id));
floats.push_back(extract_quant_float(physical_value, bit_struct, user));
}
}
// convert to i32 for bit operations
Expand All @@ -368,7 +343,8 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents(
}
}

auto first_qflt = exp->exponent_users[0]->dt->as<QuantFloatType>();
auto first_qflt =
bit_struct->get_member_type(exponent_users[0])->as<QuantFloatType>();
auto exponent_offset = get_exponent_offset(max_exp_bits, first_qflt);

auto max_exp_bits_to_store =
Expand All @@ -378,33 +354,26 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents(
create_call("max_i32", {max_exp_bits_to_store, tlctx->get_constant(0)});

// store the exponent
auto val = builder->CreateZExt(
max_exp_bits_to_store,
llvm_type(bit_struct_physical_type->get_compute_type()));
val = builder->CreateShl(val, exp->bit_offset);
auto bit_offset = bit_struct->get_member_bit_offset(i);
auto val = builder->CreateZExt(max_exp_bits_to_store, physical_type);
val = builder->CreateShl(val, bit_offset);
if (masked_val == nullptr) {
masked_val = val;
} else {
masked_val = builder->CreateOr(masked_val, val);
}
update_mask(mask, exp->dt->as<QuantIntType>()->get_num_bits(),
exp->bit_offset);
update_mask(
mask,
bit_struct->get_member_type(i)->as<QuantIntType>()->get_num_bits(),
bit_offset);

for (int c = 0; c < (int)exp->exponent_users.size(); c++) {
auto user = exp->exponent_users[c];
auto ch_id = snode->child_id(user);
for (int c = 0; c < (int)exponent_users.size(); c++) {
auto user = exponent_users[c];
auto digits =
extract_digits_from_f32_with_shared_exponent(floats[c], max_exp_bits);
auto digits_snode = snode->ch[ch_id].get();
auto qflt = digits_snode->dt->as<QuantFloatType>();
auto digits_bit_offset = digits_snode->bit_offset;

int right_shift_bits =
23 + qflt->get_is_signed() - qflt->get_digit_bits();
if (!qflt->get_is_signed()) {
// unsigned
right_shift_bits += 1;
}
auto qflt = bit_struct->get_member_type(user)->as<QuantFloatType>();
auto digits_bit_offset = bit_struct->get_member_bit_offset(user);
auto right_shift_bits = 24 - qflt->get_digit_bits();

// round to nearest
digits = builder->CreateAdd(
Expand All @@ -426,16 +395,16 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents(
}

// store the digits
val = builder->CreateZExt(digits, llvm_type(bit_struct_physical_type));
val = builder->CreateZExt(digits, physical_type);
val = builder->CreateShl(val, digits_bit_offset);
masked_val = builder->CreateOr(masked_val, val);
auto num_digit_bits =
qflt->get_digits_type()->as<QuantIntType>()->get_num_bits();
update_mask(mask, num_digit_bits, digits_bit_offset);
}
}
store_masked(llvm_val[stmt->ptr], llvm_type(bit_struct_physical_type), mask,
masked_val, stmt->is_atomic);
store_masked(llvm_val[stmt->ptr], physical_type, mask, masked_val,
stmt->is_atomic);
}

llvm::Value *CodeGenLLVM::extract_exponent_from_f32(llvm::Value *f) {
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,8 @@ int LoopIndexStmt::max_num_bits() const {
}
}

SNode *BitStructStoreStmt::get_bit_struct_snode() const {
return ptr->as<SNodeLookupStmt>()->snode;
BitStructType *BitStructStoreStmt::get_bit_struct() const {
return ptr->as<SNodeLookupStmt>()->snode->dt->as<BitStructType>();
}

TLANG_NAMESPACE_END
2 changes: 1 addition & 1 deletion taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -1686,7 +1686,7 @@ class BitStructStoreStmt : public Stmt {
TI_STMT_REG_FIELDS;
}

SNode *get_bit_struct_snode() const;
BitStructType *get_bit_struct() const;

bool common_statement_eliminable() const override {
return false;
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/optimize_bit_struct_stores.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class DemoteAtomicBitStructStores : public BasicStmtVisitor {
} else if (current_offloaded->task_type == OffloadedTaskType::range_for ||
current_offloaded->task_type == OffloadedTaskType::mesh_for ||
current_offloaded->task_type == OffloadedTaskType::struct_for) {
auto *snode = stmt->get_bit_struct_snode();
auto *snode = stmt->ptr->as<SNodeLookupStmt>()->snode;
// Find the nearest non-bit-level ancestor
while (snode->is_bit_level) {
snode = snode->parent;
Expand Down

0 comments on commit 2f3ce21

Please sign in to comment.