Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[type] [refactor] Decouple quant from SNode 7/n: Rewrite BitStructStoreStmt codegen without SNode #5475

Merged
merged 2 commits into from
Jul 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *value,
bool atomic);

void store_masked(llvm::Value *byte_ptr,
llvm::Type *byte_ptr_ty,
void store_masked(llvm::Value *ptr,
llvm::Type *ty,
uint64 mask,
llvm::Value *value,
bool atomic);
Expand All @@ -244,7 +244,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Value *quant_int_or_quant_fixed_to_bits(llvm::Value *val,
Type *input_type,
Type *output_type);
llvm::Type *output_type);

void visit(BitStructStoreStmt *stmt) override;

Expand Down
171 changes: 70 additions & 101 deletions taichi/codegen/llvm/codegen_llvm_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,26 +110,25 @@ void CodeGenLLVM::store_quant_fixed(llvm::Value *bit_ptr,
to_quant_fixed(value, qfxt), atomic);
}

void CodeGenLLVM::store_masked(llvm::Value *byte_ptr,
llvm::Type *byte_ptr_ty,
void CodeGenLLVM::store_masked(llvm::Value *ptr,
llvm::Type *ty,
uint64 mask,
llvm::Value *value,
bool atomic) {
if (!mask) {
// do not store anything
return;
}
auto physical_type = byte_ptr_ty;
uint64 full_mask = (~(uint64)0) >> (64 - physical_type->getIntegerBitWidth());
uint64 full_mask = (~(uint64)0) >> (64 - ty->getIntegerBitWidth());
if ((!atomic || prog->config.quant_opt_atomic_demotion) &&
((mask & full_mask) == full_mask)) {
builder->CreateStore(value, byte_ptr);
builder->CreateStore(value, ptr);
return;
}
create_call(fmt::format("{}set_mask_b{}", atomic ? "atomic_" : "",
physical_type->getIntegerBitWidth()),
{byte_ptr, tlctx->get_constant(mask),
builder->CreateIntCast(value, physical_type, false)});
ty->getIntegerBitWidth()),
{ptr, tlctx->get_constant(mask),
builder->CreateIntCast(value, ty, false)});
}

llvm::Value *CodeGenLLVM::get_exponent_offset(llvm::Value *exponent,
Expand All @@ -144,9 +143,10 @@ llvm::Value *CodeGenLLVM::get_exponent_offset(llvm::Value *exponent,
tlctx->get_constant(0));
}

llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(llvm::Value *val,
Type *input_type,
Type *output_type) {
llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(
llvm::Value *val,
Type *input_type,
llvm::Type *output_type) {
QuantIntType *qit = nullptr;
if (auto qfxt = input_type->cast<QuantFixedType>()) {
qit = qfxt->get_digits_type()->as<QuantIntType>();
Expand All @@ -159,53 +159,47 @@ llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(llvm::Value *val,
val, tlctx->get_constant(qit->get_compute_type(),
uint64((1ULL << qit->get_num_bits()) - 1)));
}
val = builder->CreateZExt(val, llvm_type(output_type));
val = builder->CreateZExt(val, output_type);
return val;
}

void CodeGenLLVM::visit(BitStructStoreStmt *stmt) {
auto bit_struct_snode = stmt->get_bit_struct_snode();
auto bit_struct_physical_type =
bit_struct_snode->dt->as<BitStructType>()->get_physical_type();

int bit_struct_num_non_exponent_children = 0;
for (auto &ch : bit_struct_snode->ch) {
if (ch->exponent_users.empty()) {
bit_struct_num_non_exponent_children++;
auto bit_struct = stmt->get_bit_struct();
auto physical_type = llvm_type(bit_struct->get_physical_type());

int num_non_exponent_children = 0;
for (int i = 0; i < bit_struct->get_num_members(); i++) {
if (bit_struct->get_member_exponent_users(i).empty()) {
num_non_exponent_children++;
}
}
bool store_all_components = false;
if (prog->config.quant_opt_atomic_demotion &&
stmt->ch_ids.size() == bit_struct_num_non_exponent_children) {
stmt->ch_ids.size() == num_non_exponent_children) {
stmt->is_atomic = false;
store_all_components = true;
}

bool has_shared_exponent = false;
for (auto ch_id : stmt->ch_ids) {
if (bit_struct_snode->ch[ch_id]->owns_shared_exponent) {
if (bit_struct->get_member_owns_shared_exponent(ch_id)) {
has_shared_exponent = true;
}
}
// TODO: what about storing only shared-exponent floating-point SNodes
// that don't own the shared exponent?

if (has_shared_exponent) {
store_quant_floats_with_shared_exponents(stmt);
}

llvm::Value *bit_struct_val = nullptr;
for (int i = 0; i < stmt->ch_ids.size(); i++) {
auto ch_id = stmt->ch_ids[i];
auto val = llvm_val[stmt->values[i]];
auto &ch = bit_struct_snode->ch[ch_id];
if (has_shared_exponent && ch->exp_snode != nullptr &&
ch->exp_snode->exponent_users.size() > 1) {
auto exp = bit_struct->get_member_exponent(ch_id);
if (exp != -1 && bit_struct->get_member_exponent_users(exp).size() > 1) {
// already handled in store_quant_floats_with_shared_exponents
continue;
}
auto dtype = ch->dt;

auto dtype = bit_struct->get_member_type(ch_id);
auto val = llvm_val[stmt->values[i]];
if (auto qflt = dtype->cast<QuantFloatType>()) {
// Quant float type with non-shared exponent.
llvm::Value *digit_bits = nullptr;
Expand Down Expand Up @@ -243,20 +237,14 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) {
builder->CreateLShr(sign_bit, 31 - qflt->get_digit_bits()));
}

auto digits_snode = ch.get();
auto exponent_snode = digits_snode->exp_snode;

auto exponent_offset = get_exponent_offset(exponent_bits, qflt);
exponent_bits = builder->CreateSub(exponent_bits, exponent_offset);
exponent_bits =
create_call("max_i32", {exponent_bits, tlctx->get_constant(0)});

// Compute the bit pointer of the exponent bits.
TI_ASSERT(digits_snode->parent == exponent_snode->parent);

val = builder->CreateBitCast(exponent_bits,
llvm_type(bit_struct_physical_type));
val = builder->CreateShl(val, exponent_snode->bit_offset);
val = builder->CreateBitCast(exponent_bits, physical_type);
val = builder->CreateShl(val, bit_struct->get_member_bit_offset(exp));

if (bit_struct_val == nullptr) {
bit_struct_val = val;
Expand All @@ -272,12 +260,11 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) {
tlctx->get_constant(0));
val = builder->CreateSelect(exp_non_zero, digit_bits,
tlctx->get_constant(0));
val = builder->CreateBitCast(val, llvm_type(bit_struct_physical_type));
val = builder->CreateShl(val, digits_snode->bit_offset);
val = builder->CreateBitCast(val, physical_type);
val = builder->CreateShl(val, bit_struct->get_member_bit_offset(ch_id));
} else {
val = quant_int_or_quant_fixed_to_bits(val, dtype,
bit_struct_physical_type);
val = builder->CreateShl(val, bit_struct_snode->ch[ch_id]->bit_offset);
val = quant_int_or_quant_fixed_to_bits(val, dtype, physical_type);
val = builder->CreateShl(val, bit_struct->get_member_bit_offset(ch_id));
}

if (bit_struct_val == nullptr) {
Expand All @@ -292,68 +279,56 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) {
} else {
// Create a mask and use a single (atomic)CAS
uint64 mask = 0;
for (auto &ch_id : stmt->ch_ids) {
auto &ch = bit_struct_snode->ch[ch_id];
if (has_shared_exponent && ch->exp_snode != nullptr &&
ch->exp_snode->exponent_users.size() > 1) {
for (int i = 0; i < stmt->ch_ids.size(); i++) {
auto ch_id = stmt->ch_ids[i];
auto exp = bit_struct->get_member_exponent(ch_id);
if (exp != -1 && bit_struct->get_member_exponent_users(exp).size() > 1) {
// already handled in store_quant_floats_with_shared_exponents
continue;
}
auto dtype = ch->dt;
auto dtype = bit_struct->get_member_type(ch_id);
QuantIntType *qit = nullptr;
if (auto qflt = dtype->cast<QuantFloatType>()) {
auto exp = qflt->get_exponent_type();
auto exponent_qit = exp->as<QuantIntType>();
auto exponent_snode = ch->exp_snode;
auto exponent_qit = qflt->get_exponent_type()->as<QuantIntType>();
update_mask(mask, exponent_qit->get_num_bits(),
exponent_snode->bit_offset);
bit_struct->get_member_bit_offset(exp));
qit = qflt->get_digits_type()->as<QuantIntType>();
} else if (auto qfxt = dtype->cast<QuantFixedType>()) {
qit = qfxt->get_digits_type()->as<QuantIntType>();
} else {
qit = dtype->as<QuantIntType>();
}
update_mask(mask, qit->get_num_bits(), ch->bit_offset);
update_mask(mask, qit->get_num_bits(),
bit_struct->get_member_bit_offset(ch_id));
}
store_masked(llvm_val[stmt->ptr], llvm_type(bit_struct_physical_type), mask,
bit_struct_val, stmt->is_atomic);
store_masked(llvm_val[stmt->ptr], physical_type, mask, bit_struct_val,
stmt->is_atomic);
}
}

void CodeGenLLVM::store_quant_floats_with_shared_exponents(
BitStructStoreStmt *stmt) {
// handle each exponent separately
auto snode = stmt->get_bit_struct_snode();
auto bit_struct = snode->dt->as<BitStructType>();
auto bit_struct_physical_type = bit_struct->get_physical_type();
auto local_bit_struct = builder->CreateLoad(
#ifdef TI_LLVM_15
llvm_type(bit_struct_physical_type),
#endif
llvm_val[stmt->ptr]);
auto bit_struct = stmt->get_bit_struct();
auto physical_type = llvm_type(bit_struct->get_physical_type());
auto physical_value = builder->CreateLoad(physical_type, llvm_val[stmt->ptr]);
// fuse all stores into a masked store
llvm::Value *masked_val = nullptr;
uint64 mask = 0;
for (int i = 0; i < (int)snode->ch.size(); i++) {
if (snode->ch[i]->exponent_users.empty())
for (int i = 0; i < bit_struct->get_num_members(); i++) {
auto &exponent_users = bit_struct->get_member_exponent_users(i);
// make sure i-th member is a shared exponent
if (exponent_users.size() < 2)
continue;
// ch[i] must be an exponent SNode
auto &exp = snode->ch[i];
if (exp->exponent_users.size() == 1) {
// non-shared
continue;
}
// load all floats
// load all floats with the shared exponent
std::vector<llvm::Value *> floats;
for (auto &user : exp->exponent_users) {
auto ch_id = snode->child_id(user);
for (auto user : exponent_users) {
if (auto input =
std::find(stmt->ch_ids.begin(), stmt->ch_ids.end(), ch_id);
std::find(stmt->ch_ids.begin(), stmt->ch_ids.end(), user);
input != stmt->ch_ids.end()) {
floats.push_back(llvm_val[stmt->values[input - stmt->ch_ids.begin()]]);
} else {
floats.push_back(
extract_quant_float(local_bit_struct, bit_struct, ch_id));
floats.push_back(extract_quant_float(physical_value, bit_struct, user));
}
}
// convert to i32 for bit operations
Expand All @@ -368,7 +343,8 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents(
}
}

auto first_qflt = exp->exponent_users[0]->dt->as<QuantFloatType>();
auto first_qflt =
bit_struct->get_member_type(exponent_users[0])->as<QuantFloatType>();
auto exponent_offset = get_exponent_offset(max_exp_bits, first_qflt);

auto max_exp_bits_to_store =
Expand All @@ -378,33 +354,26 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents(
create_call("max_i32", {max_exp_bits_to_store, tlctx->get_constant(0)});

// store the exponent
auto val = builder->CreateZExt(
max_exp_bits_to_store,
llvm_type(bit_struct_physical_type->get_compute_type()));
val = builder->CreateShl(val, exp->bit_offset);
auto bit_offset = bit_struct->get_member_bit_offset(i);
auto val = builder->CreateZExt(max_exp_bits_to_store, physical_type);
val = builder->CreateShl(val, bit_offset);
if (masked_val == nullptr) {
masked_val = val;
} else {
masked_val = builder->CreateOr(masked_val, val);
}
update_mask(mask, exp->dt->as<QuantIntType>()->get_num_bits(),
exp->bit_offset);
update_mask(
mask,
bit_struct->get_member_type(i)->as<QuantIntType>()->get_num_bits(),
bit_offset);

for (int c = 0; c < (int)exp->exponent_users.size(); c++) {
auto user = exp->exponent_users[c];
auto ch_id = snode->child_id(user);
for (int c = 0; c < (int)exponent_users.size(); c++) {
auto user = exponent_users[c];
auto digits =
extract_digits_from_f32_with_shared_exponent(floats[c], max_exp_bits);
auto digits_snode = snode->ch[ch_id].get();
auto qflt = digits_snode->dt->as<QuantFloatType>();
auto digits_bit_offset = digits_snode->bit_offset;

int right_shift_bits =
23 + qflt->get_is_signed() - qflt->get_digit_bits();
if (!qflt->get_is_signed()) {
// unsigned
right_shift_bits += 1;
}
auto qflt = bit_struct->get_member_type(user)->as<QuantFloatType>();
auto digits_bit_offset = bit_struct->get_member_bit_offset(user);
auto right_shift_bits = 24 - qflt->get_digit_bits();

// round to nearest
digits = builder->CreateAdd(
Expand All @@ -426,16 +395,16 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents(
}

// store the digits
val = builder->CreateZExt(digits, llvm_type(bit_struct_physical_type));
val = builder->CreateZExt(digits, physical_type);
val = builder->CreateShl(val, digits_bit_offset);
masked_val = builder->CreateOr(masked_val, val);
auto num_digit_bits =
qflt->get_digits_type()->as<QuantIntType>()->get_num_bits();
update_mask(mask, num_digit_bits, digits_bit_offset);
}
}
store_masked(llvm_val[stmt->ptr], llvm_type(bit_struct_physical_type), mask,
masked_val, stmt->is_atomic);
store_masked(llvm_val[stmt->ptr], physical_type, mask, masked_val,
stmt->is_atomic);
}

llvm::Value *CodeGenLLVM::extract_exponent_from_f32(llvm::Value *f) {
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,8 @@ int LoopIndexStmt::max_num_bits() const {
}
}

SNode *BitStructStoreStmt::get_bit_struct_snode() const {
return ptr->as<SNodeLookupStmt>()->snode;
BitStructType *BitStructStoreStmt::get_bit_struct() const {
return ptr->as<SNodeLookupStmt>()->snode->dt->as<BitStructType>();
}

TLANG_NAMESPACE_END
2 changes: 1 addition & 1 deletion taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -1686,7 +1686,7 @@ class BitStructStoreStmt : public Stmt {
TI_STMT_REG_FIELDS;
}

SNode *get_bit_struct_snode() const;
BitStructType *get_bit_struct() const;

bool common_statement_eliminable() const override {
return false;
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/optimize_bit_struct_stores.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class DemoteAtomicBitStructStores : public BasicStmtVisitor {
} else if (current_offloaded->task_type == OffloadedTaskType::range_for ||
current_offloaded->task_type == OffloadedTaskType::mesh_for ||
current_offloaded->task_type == OffloadedTaskType::struct_for) {
auto *snode = stmt->get_bit_struct_snode();
auto *snode = stmt->ptr->as<SNodeLookupStmt>()->snode;
// Find the nearest non-bit-level ancestor
while (snode->is_bit_level) {
snode = snode->parent;
Expand Down