From 073aa72fededf413a0885bcb83d0374c9c2025eb Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Tue, 19 Jul 2022 20:10:34 +0800 Subject: [PATCH 1/2] [type] [refactor] Rewrite BitStructStoreStmt codegen without SNode --- taichi/codegen/llvm/codegen_llvm.h | 6 +- taichi/codegen/llvm/codegen_llvm_quant.cpp | 160 +++++++----------- taichi/ir/statements.cpp | 4 +- taichi/ir/statements.h | 2 +- .../transforms/optimize_bit_struct_stores.cpp | 2 +- 5 files changed, 68 insertions(+), 106 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h index c4d052a220197..e28218ddffbad 100644 --- a/taichi/codegen/llvm/codegen_llvm.h +++ b/taichi/codegen/llvm/codegen_llvm.h @@ -234,8 +234,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *value, bool atomic); - void store_masked(llvm::Value *byte_ptr, - llvm::Type *byte_ptr_ty, + void store_masked(llvm::Value *ptr, + llvm::Type *ty, uint64 mask, llvm::Value *value, bool atomic); @@ -244,7 +244,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *quant_int_or_quant_fixed_to_bits(llvm::Value *val, Type *input_type, - Type *output_type); + llvm::Type *output_type); void visit(BitStructStoreStmt *stmt) override; diff --git a/taichi/codegen/llvm/codegen_llvm_quant.cpp b/taichi/codegen/llvm/codegen_llvm_quant.cpp index 6acb4e8753570..99a457e159393 100644 --- a/taichi/codegen/llvm/codegen_llvm_quant.cpp +++ b/taichi/codegen/llvm/codegen_llvm_quant.cpp @@ -110,8 +110,8 @@ void CodeGenLLVM::store_quant_fixed(llvm::Value *bit_ptr, to_quant_fixed(value, qfxt), atomic); } -void CodeGenLLVM::store_masked(llvm::Value *byte_ptr, - llvm::Type *byte_ptr_ty, +void CodeGenLLVM::store_masked(llvm::Value *ptr, + llvm::Type *ty, uint64 mask, llvm::Value *value, bool atomic) { @@ -119,17 +119,16 @@ void CodeGenLLVM::store_masked(llvm::Value *byte_ptr, // do not store anything return; } - auto physical_type = byte_ptr_ty; - uint64 full_mask = (~(uint64)0) >> (64 - physical_type->getIntegerBitWidth()); + uint64 full_mask = (~(uint64)0) >> (64 - ty->getIntegerBitWidth()); if ((!atomic || prog->config.quant_opt_atomic_demotion) && ((mask & full_mask) == full_mask)) { - builder->CreateStore(value, byte_ptr); + builder->CreateStore(value, ptr); return; } create_call(fmt::format("{}set_mask_b{}", atomic ? "atomic_" : "", - physical_type->getIntegerBitWidth()), - {byte_ptr, tlctx->get_constant(mask), - builder->CreateIntCast(value, physical_type, false)}); + ty->getIntegerBitWidth()), + {ptr, tlctx->get_constant(mask), + builder->CreateIntCast(value, ty, false)}); } llvm::Value *CodeGenLLVM::get_exponent_offset(llvm::Value *exponent, @@ -144,9 +143,7 @@ llvm::Value *CodeGenLLVM::get_exponent_offset(llvm::Value *exponent, tlctx->get_constant(0)); } -llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(llvm::Value *val, - Type *input_type, - Type *output_type) { +llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(llvm::Value *val, Type *input_type, llvm::Type *output_type) { QuantIntType *qit = nullptr; if (auto qfxt = input_type->cast()) { qit = qfxt->get_digits_type()->as(); @@ -159,37 +156,33 @@ llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(llvm::Value *val, val, tlctx->get_constant(qit->get_compute_type(), uint64((1ULL << qit->get_num_bits()) - 1))); } - val = builder->CreateZExt(val, llvm_type(output_type)); + val = builder->CreateZExt(val, output_type); return val; } void CodeGenLLVM::visit(BitStructStoreStmt *stmt) { - auto bit_struct_snode = stmt->get_bit_struct_snode(); - auto bit_struct_physical_type = - bit_struct_snode->dt->as()->get_physical_type(); - - int bit_struct_num_non_exponent_children = 0; - for (auto &ch : bit_struct_snode->ch) { - if (ch->exponent_users.empty()) { - bit_struct_num_non_exponent_children++; + auto bit_struct = stmt->get_bit_struct(); + auto physical_type = llvm_type(bit_struct->get_physical_type()); + + int num_non_exponent_children = 0; + for (int i = 0; i < bit_struct->get_num_members(); i++) { + if (bit_struct->get_member_exponent_users(i).empty()) { + num_non_exponent_children++; } } bool store_all_components = false; if (prog->config.quant_opt_atomic_demotion && - stmt->ch_ids.size() == bit_struct_num_non_exponent_children) { + stmt->ch_ids.size() == num_non_exponent_children) { stmt->is_atomic = false; store_all_components = true; } bool has_shared_exponent = false; for (auto ch_id : stmt->ch_ids) { - if (bit_struct_snode->ch[ch_id]->owns_shared_exponent) { + if (bit_struct->get_member_owns_shared_exponent(ch_id)) { has_shared_exponent = true; } } - // TODO: what about storing only shared-exponent floating-point SNodes - // that don't own the shared exponent? - if (has_shared_exponent) { store_quant_floats_with_shared_exponents(stmt); } @@ -197,15 +190,13 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) { llvm::Value *bit_struct_val = nullptr; for (int i = 0; i < stmt->ch_ids.size(); i++) { auto ch_id = stmt->ch_ids[i]; - auto val = llvm_val[stmt->values[i]]; - auto &ch = bit_struct_snode->ch[ch_id]; - if (has_shared_exponent && ch->exp_snode != nullptr && - ch->exp_snode->exponent_users.size() > 1) { + auto exp = bit_struct->get_member_exponent(ch_id); + if (exp != -1 && bit_struct->get_member_exponent_users(exp).size() > 1) { // already handled in store_quant_floats_with_shared_exponents continue; } - auto dtype = ch->dt; - + auto dtype = bit_struct->get_member_type(ch_id); + auto val = llvm_val[stmt->values[i]]; if (auto qflt = dtype->cast()) { // Quant float type with non-shared exponent. llvm::Value *digit_bits = nullptr; @@ -243,20 +234,14 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) { builder->CreateLShr(sign_bit, 31 - qflt->get_digit_bits())); } - auto digits_snode = ch.get(); - auto exponent_snode = digits_snode->exp_snode; - auto exponent_offset = get_exponent_offset(exponent_bits, qflt); exponent_bits = builder->CreateSub(exponent_bits, exponent_offset); exponent_bits = create_call("max_i32", {exponent_bits, tlctx->get_constant(0)}); // Compute the bit pointer of the exponent bits. - TI_ASSERT(digits_snode->parent == exponent_snode->parent); - - val = builder->CreateBitCast(exponent_bits, - llvm_type(bit_struct_physical_type)); - val = builder->CreateShl(val, exponent_snode->bit_offset); + val = builder->CreateBitCast(exponent_bits, physical_type); + val = builder->CreateShl(val, bit_struct->get_member_bit_offset(exp)); if (bit_struct_val == nullptr) { bit_struct_val = val; @@ -272,12 +257,11 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) { tlctx->get_constant(0)); val = builder->CreateSelect(exp_non_zero, digit_bits, tlctx->get_constant(0)); - val = builder->CreateBitCast(val, llvm_type(bit_struct_physical_type)); - val = builder->CreateShl(val, digits_snode->bit_offset); + val = builder->CreateBitCast(val, physical_type); + val = builder->CreateShl(val, bit_struct->get_member_bit_offset(ch_id)); } else { - val = quant_int_or_quant_fixed_to_bits(val, dtype, - bit_struct_physical_type); - val = builder->CreateShl(val, bit_struct_snode->ch[ch_id]->bit_offset); + val = quant_int_or_quant_fixed_to_bits(val, dtype, physical_type); + val = builder->CreateShl(val, bit_struct->get_member_bit_offset(ch_id)); } if (bit_struct_val == nullptr) { @@ -292,68 +276,55 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) { } else { // Create a mask and use a single (atomic)CAS uint64 mask = 0; - for (auto &ch_id : stmt->ch_ids) { - auto &ch = bit_struct_snode->ch[ch_id]; - if (has_shared_exponent && ch->exp_snode != nullptr && - ch->exp_snode->exponent_users.size() > 1) { + for (int i = 0; i < stmt->ch_ids.size(); i++) { + auto ch_id = stmt->ch_ids[i]; + auto exp = bit_struct->get_member_exponent(ch_id); + if (exp != -1 && bit_struct->get_member_exponent_users(exp).size() > 1) { // already handled in store_quant_floats_with_shared_exponents continue; } - auto dtype = ch->dt; + auto dtype = bit_struct->get_member_type(ch_id); QuantIntType *qit = nullptr; if (auto qflt = dtype->cast()) { - auto exp = qflt->get_exponent_type(); - auto exponent_qit = exp->as(); - auto exponent_snode = ch->exp_snode; + auto exponent_qit = qflt->get_exponent_type()->as(); update_mask(mask, exponent_qit->get_num_bits(), - exponent_snode->bit_offset); + bit_struct->get_member_bit_offset(exp)); qit = qflt->get_digits_type()->as(); } else if (auto qfxt = dtype->cast()) { qit = qfxt->get_digits_type()->as(); } else { qit = dtype->as(); } - update_mask(mask, qit->get_num_bits(), ch->bit_offset); + update_mask(mask, qit->get_num_bits(), bit_struct->get_member_bit_offset(ch_id)); } - store_masked(llvm_val[stmt->ptr], llvm_type(bit_struct_physical_type), mask, - bit_struct_val, stmt->is_atomic); + store_masked(llvm_val[stmt->ptr], physical_type, mask, bit_struct_val, stmt->is_atomic); } } void CodeGenLLVM::store_quant_floats_with_shared_exponents( BitStructStoreStmt *stmt) { // handle each exponent separately - auto snode = stmt->get_bit_struct_snode(); - auto bit_struct = snode->dt->as(); - auto bit_struct_physical_type = bit_struct->get_physical_type(); - auto local_bit_struct = builder->CreateLoad( -#ifdef TI_LLVM_15 - llvm_type(bit_struct_physical_type), -#endif - llvm_val[stmt->ptr]); + auto bit_struct = stmt->get_bit_struct(); + auto physical_type = llvm_type(bit_struct->get_physical_type()); + auto physical_value = builder->CreateLoad(physical_type, llvm_val[stmt->ptr]); // fuse all stores into a masked store llvm::Value *masked_val = nullptr; uint64 mask = 0; - for (int i = 0; i < (int)snode->ch.size(); i++) { - if (snode->ch[i]->exponent_users.empty()) + for (int i = 0; i < bit_struct->get_num_members(); i++) { + auto &exponent_users = bit_struct->get_member_exponent_users(i); + // make sure i-th member is a shared exponent + if (exponent_users.size() < 2) continue; - // ch[i] must be an exponent SNode - auto &exp = snode->ch[i]; - if (exp->exponent_users.size() == 1) { - // non-shared - continue; - } - // load all floats + // load all floats with the shared exponent std::vector floats; - for (auto &user : exp->exponent_users) { - auto ch_id = snode->child_id(user); + for (auto user : exponent_users) { if (auto input = - std::find(stmt->ch_ids.begin(), stmt->ch_ids.end(), ch_id); + std::find(stmt->ch_ids.begin(), stmt->ch_ids.end(), user); input != stmt->ch_ids.end()) { floats.push_back(llvm_val[stmt->values[input - stmt->ch_ids.begin()]]); } else { floats.push_back( - extract_quant_float(local_bit_struct, bit_struct, ch_id)); + extract_quant_float(physical_value, bit_struct, user)); } } // convert to i32 for bit operations @@ -368,7 +339,7 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( } } - auto first_qflt = exp->exponent_users[0]->dt->as(); + auto first_qflt = bit_struct->get_member_type(exponent_users[0])->as(); auto exponent_offset = get_exponent_offset(max_exp_bits, first_qflt); auto max_exp_bits_to_store = @@ -378,33 +349,24 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( create_call("max_i32", {max_exp_bits_to_store, tlctx->get_constant(0)}); // store the exponent - auto val = builder->CreateZExt( - max_exp_bits_to_store, - llvm_type(bit_struct_physical_type->get_compute_type())); - val = builder->CreateShl(val, exp->bit_offset); + auto bit_offset = bit_struct->get_member_bit_offset(i); + auto val = builder->CreateZExt(max_exp_bits_to_store, physical_type); + val = builder->CreateShl(val, bit_offset); if (masked_val == nullptr) { masked_val = val; } else { masked_val = builder->CreateOr(masked_val, val); } - update_mask(mask, exp->dt->as()->get_num_bits(), - exp->bit_offset); + update_mask(mask, bit_struct->get_member_type(i)->as()->get_num_bits(), + bit_offset); - for (int c = 0; c < (int)exp->exponent_users.size(); c++) { - auto user = exp->exponent_users[c]; - auto ch_id = snode->child_id(user); + for (int c = 0; c < (int)exponent_users.size(); c++) { + auto user = exponent_users[c]; auto digits = extract_digits_from_f32_with_shared_exponent(floats[c], max_exp_bits); - auto digits_snode = snode->ch[ch_id].get(); - auto qflt = digits_snode->dt->as(); - auto digits_bit_offset = digits_snode->bit_offset; - - int right_shift_bits = - 23 + qflt->get_is_signed() - qflt->get_digit_bits(); - if (!qflt->get_is_signed()) { - // unsigned - right_shift_bits += 1; - } + auto qflt = bit_struct->get_member_type(user)->as(); + auto digits_bit_offset = bit_struct->get_member_bit_offset(user); + auto right_shift_bits = 24 - qflt->get_digit_bits(); // round to nearest digits = builder->CreateAdd( @@ -426,7 +388,7 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( } // store the digits - val = builder->CreateZExt(digits, llvm_type(bit_struct_physical_type)); + val = builder->CreateZExt(digits, physical_type); val = builder->CreateShl(val, digits_bit_offset); masked_val = builder->CreateOr(masked_val, val); auto num_digit_bits = @@ -434,7 +396,7 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( update_mask(mask, num_digit_bits, digits_bit_offset); } } - store_masked(llvm_val[stmt->ptr], llvm_type(bit_struct_physical_type), mask, + store_masked(llvm_val[stmt->ptr], physical_type, mask, masked_val, stmt->is_atomic); } diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 10d455031a9be..c901c32eba7e5 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -497,8 +497,8 @@ int LoopIndexStmt::max_num_bits() const { } } -SNode *BitStructStoreStmt::get_bit_struct_snode() const { - return ptr->as()->snode; +BitStructType *BitStructStoreStmt::get_bit_struct() const { + return ptr->as()->snode->dt->as(); } TLANG_NAMESPACE_END diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 0ffa885f1fb27..65ee7d401dfad 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -1686,7 +1686,7 @@ class BitStructStoreStmt : public Stmt { TI_STMT_REG_FIELDS; } - SNode *get_bit_struct_snode() const; + BitStructType *get_bit_struct() const; bool common_statement_eliminable() const override { return false; diff --git a/taichi/transforms/optimize_bit_struct_stores.cpp b/taichi/transforms/optimize_bit_struct_stores.cpp index 0c771a4b15577..1993298717129 100644 --- a/taichi/transforms/optimize_bit_struct_stores.cpp +++ b/taichi/transforms/optimize_bit_struct_stores.cpp @@ -146,7 +146,7 @@ class DemoteAtomicBitStructStores : public BasicStmtVisitor { } else if (current_offloaded->task_type == OffloadedTaskType::range_for || current_offloaded->task_type == OffloadedTaskType::mesh_for || current_offloaded->task_type == OffloadedTaskType::struct_for) { - auto *snode = stmt->get_bit_struct_snode(); + auto *snode = stmt->ptr->as()->snode; // Find the nearest non-bit-level ancestor while (snode->is_bit_level) { snode = snode->parent; From 331a41f9595695ab4c0de72a5afdf311f5724582 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Jul 2022 12:14:23 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/codegen/llvm/codegen_llvm_quant.cpp | 27 ++++++++++++++-------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm_quant.cpp b/taichi/codegen/llvm/codegen_llvm_quant.cpp index 99a457e159393..2781b134b7b0e 100644 --- a/taichi/codegen/llvm/codegen_llvm_quant.cpp +++ b/taichi/codegen/llvm/codegen_llvm_quant.cpp @@ -143,7 +143,10 @@ llvm::Value *CodeGenLLVM::get_exponent_offset(llvm::Value *exponent, tlctx->get_constant(0)); } -llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(llvm::Value *val, Type *input_type, llvm::Type *output_type) { +llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits( + llvm::Value *val, + Type *input_type, + llvm::Type *output_type) { QuantIntType *qit = nullptr; if (auto qfxt = input_type->cast()) { qit = qfxt->get_digits_type()->as(); @@ -295,9 +298,11 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) { } else { qit = dtype->as(); } - update_mask(mask, qit->get_num_bits(), bit_struct->get_member_bit_offset(ch_id)); + update_mask(mask, qit->get_num_bits(), + bit_struct->get_member_bit_offset(ch_id)); } - store_masked(llvm_val[stmt->ptr], physical_type, mask, bit_struct_val, stmt->is_atomic); + store_masked(llvm_val[stmt->ptr], physical_type, mask, bit_struct_val, + stmt->is_atomic); } } @@ -323,8 +328,7 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( input != stmt->ch_ids.end()) { floats.push_back(llvm_val[stmt->values[input - stmt->ch_ids.begin()]]); } else { - floats.push_back( - extract_quant_float(physical_value, bit_struct, user)); + floats.push_back(extract_quant_float(physical_value, bit_struct, user)); } } // convert to i32 for bit operations @@ -339,7 +343,8 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( } } - auto first_qflt = bit_struct->get_member_type(exponent_users[0])->as(); + auto first_qflt = + bit_struct->get_member_type(exponent_users[0])->as(); auto exponent_offset = get_exponent_offset(max_exp_bits, first_qflt); auto max_exp_bits_to_store = @@ -357,8 +362,10 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( } else { masked_val = builder->CreateOr(masked_val, val); } - update_mask(mask, bit_struct->get_member_type(i)->as()->get_num_bits(), - bit_offset); + update_mask( + mask, + bit_struct->get_member_type(i)->as()->get_num_bits(), + bit_offset); for (int c = 0; c < (int)exponent_users.size(); c++) { auto user = exponent_users[c]; @@ -396,8 +403,8 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( update_mask(mask, num_digit_bits, digits_bit_offset); } } - store_masked(llvm_val[stmt->ptr], physical_type, mask, - masked_val, stmt->is_atomic); + store_masked(llvm_val[stmt->ptr], physical_type, mask, masked_val, + stmt->is_atomic); } llvm::Value *CodeGenLLVM::extract_exponent_from_f32(llvm::Value *f) {