From bb1985be40e9016499ce8180f6ca14a6acf98ed9 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Wed, 22 Jun 2022 17:56:32 +0800 Subject: [PATCH 1/2] [type] [refactor] Decouple quant from SNode 3/n: Extend bit pointers --- taichi/backends/cuda/codegen_cuda.cpp | 4 +- taichi/codegen/codegen_llvm.cpp | 91 +++++++--------------- taichi/codegen/codegen_llvm.h | 31 ++------ taichi/codegen/codegen_llvm_quant.cpp | 108 +++++++++----------------- 4 files changed, 76 insertions(+), 158 deletions(-) diff --git a/taichi/backends/cuda/codegen_cuda.cpp b/taichi/backends/cuda/codegen_cuda.cpp index 06cde97c0b4d8..1269d076a178f 100644 --- a/taichi/backends/cuda/codegen_cuda.cpp +++ b/taichi/backends/cuda/codegen_cuda.cpp @@ -540,10 +540,10 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { auto val_type = ptr_type->get_pointee_type(); if (auto qit = val_type->cast()) { dtype = get_ch->input_snode->physical_type; - auto [data_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->src]); + auto [data_ptr, bit_offset] = load_bit_ptr(llvm_val[stmt->src]); data_ptr = builder->CreateBitCast(data_ptr, llvm_ptr_type(dtype)); auto data = create_intrinsic_load(dtype, data_ptr); - llvm_val[stmt] = extract_quant_int(data, bit_offset, qit, dtype); + llvm_val[stmt] = extract_quant_int(data, bit_offset, qit); } else { // TODO: support __ldg TI_ASSERT(val_type->is() || diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 7444fcf0d22b8..8a073661910aa 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1195,11 +1195,9 @@ llvm::Value *CodeGenLLVM::quant_type_atomic(AtomicOpStmt *stmt) { auto dst_type = stmt->dest->ret_type->as()->get_pointee_type(); if (auto qit = dst_type->cast()) { - return atomic_add_quant_int( - stmt, qit, stmt->dest->as()->input_snode->physical_type); + return atomic_add_quant_int(stmt, qit); } else if (auto qfxt = dst_type->cast()) { - return atomic_add_quant_fixed( - stmt, qfxt, stmt->dest->as()->input_snode->physical_type); + return atomic_add_quant_fixed(stmt, qfxt); } else { return nullptr; } @@ -1354,7 +1352,6 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { } } store_quant_int(llvm_val[stmt->dest], pointee_type->as(), - stmt->dest->as()->input_snode->physical_type, llvm_val[stmt->val], true); } else { builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]); @@ -1368,9 +1365,7 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { if (ptr_type->is_bit_pointer()) { auto val_type = ptr_type->get_pointee_type(); if (auto qit = val_type->cast()) { - llvm_val[stmt] = load_quant_int( - llvm_val[stmt->src], qit, - stmt->src->as()->input_snode->physical_type); + llvm_val[stmt] = load_quant_int(llvm_val[stmt->src], qit); } else { TI_ASSERT(val_type->is() || val_type->is()); @@ -1479,63 +1474,35 @@ void CodeGenLLVM::visit(LinearizeStmt *stmt) { void CodeGenLLVM::visit(IntegerOffsetStmt *stmt){TI_NOT_IMPLEMENTED} -llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base, - llvm::Value *bit_offset) { - // 1. get the bit pointer LLVM struct - // struct bit_pointer { - // i8* byte_ptr; - // i32 offset; +llvm::Value *CodeGenLLVM::create_bit_ptr(llvm::Value *byte_ptr, llvm::Value *bit_offset) { + // 1. define the bit pointer struct (X=8/16/32/64) + // struct bit_pointer_X { + // iX* byte_ptr; + // i32 bit_offset; // }; - auto struct_type = llvm::StructType::get( - *llvm_context, {llvm::Type::getInt8PtrTy(*llvm_context), - llvm::Type::getInt32Ty(*llvm_context)}); - // 2. allocate the bit pointer struct - auto bit_ptr_struct = create_entry_block_alloca(struct_type); - // 3. store `byte_ptr_base` into `bit_ptr_struct` (if provided) - if (byte_ptr_base) { - auto byte_ptr = builder->CreateBitCast( - byte_ptr_base, llvm::PointerType::getInt8PtrTy(*llvm_context)); - builder->CreateStore( - byte_ptr, builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0), - tlctx->get_constant(0)})); - } - // 4. store `offset` in `bit_ptr_struct` (if provided) - if (bit_offset) { - builder->CreateStore( - bit_offset, - builder->CreateGEP(bit_ptr_struct, - {tlctx->get_constant(0), tlctx->get_constant(1)})); - } - return bit_ptr_struct; -} - -llvm::Value *CodeGenLLVM::offset_bit_ptr(llvm::Value *input_bit_ptr, - int bit_offset_delta) { - auto byte_ptr_base = builder->CreateLoad(builder->CreateGEP( - input_bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)})); - auto input_offset = builder->CreateLoad(builder->CreateGEP( - input_bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(1)})); - auto new_bit_offset = - builder->CreateAdd(input_offset, tlctx->get_constant(bit_offset_delta)); - return create_bit_ptr_struct(byte_ptr_base, new_bit_offset); -} - -std::tuple CodeGenLLVM::load_bit_pointer( - llvm::Value *ptr) { - // 1. load byte pointer - auto byte_ptr_in_bit_struct = - builder->CreateGEP(ptr, {tlctx->get_constant(0), tlctx->get_constant(0)}); - auto byte_ptr = builder->CreateLoad(byte_ptr_in_bit_struct); - TI_ASSERT(byte_ptr->getType()->getPointerElementType()->isIntegerTy(8)); - - // 2. load bit offset - auto bit_offset_in_bit_struct = - builder->CreateGEP(ptr, {tlctx->get_constant(0), tlctx->get_constant(1)}); - auto bit_offset = builder->CreateLoad(bit_offset_in_bit_struct); TI_ASSERT(bit_offset->getType()->isIntegerTy(32)); + auto struct_type = llvm::StructType::get(*llvm_context, {byte_ptr->getType(), bit_offset->getType()}); + // 2. allocate the bit pointer struct + auto bit_ptr = create_entry_block_alloca(struct_type); + // 3. store `byte_ptr` + builder->CreateStore(byte_ptr, builder->CreateGEP(bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)})); + // 4. store `bit_offset + builder->CreateStore(bit_offset,builder->CreateGEP(bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(1)})); + return bit_ptr; +} + +std::tuple CodeGenLLVM::load_bit_ptr(llvm::Value *bit_ptr) { + auto byte_ptr = builder->CreateLoad(builder->CreateGEP(bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)})); + auto bit_offset = builder->CreateLoad(builder->CreateGEP(bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(1)})); return std::make_tuple(byte_ptr, bit_offset); } +llvm::Value *CodeGenLLVM::offset_bit_ptr(llvm::Value *bit_ptr, int bit_offset_delta) { + auto [byte_ptr, bit_offset] = load_bit_ptr(bit_ptr); + auto new_bit_offset = builder->CreateAdd(bit_offset, tlctx->get_constant(bit_offset_delta)); + return create_bit_ptr(byte_ptr, new_bit_offset); +} + void CodeGenLLVM::visit(SNodeLookupStmt *stmt) { llvm::Value *parent = nullptr; parent = llvm_val[stmt->input_snode]; @@ -1560,7 +1527,7 @@ void CodeGenLLVM::visit(SNodeLookupStmt *stmt) { snode->dt->as()->get_element_num_bits(); auto offset = tlctx->get_constant(element_num_bits); offset = builder->CreateMul(offset, llvm_val[stmt->input_index]); - llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_snode], offset); + llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_snode], offset); } else { TI_INFO(snode_type_name(snode->type)); TI_NOT_IMPLEMENTED @@ -1575,7 +1542,7 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { auto bit_offset = bit_struct->get_member_bit_offset( stmt->input_snode->child_id(stmt->output_snode)); auto offset = tlctx->get_constant(bit_offset); - llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_ptr], offset); + llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_ptr], offset); } else { auto ch = create_call(stmt->output_snode->get_ch_from_parent_func_name(), {builder->CreateBitCast( diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index 76b843bf4be20..9b9dd88b7f41e 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -219,13 +219,9 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(SNodeOpStmt *stmt) override; - llvm::Value *atomic_add_quant_fixed(AtomicOpStmt *stmt, - QuantFixedType *qfxt, - Type *physical_type); + llvm::Value *atomic_add_quant_fixed(AtomicOpStmt *stmt, QuantFixedType *qfxt); - llvm::Value *atomic_add_quant_int(AtomicOpStmt *stmt, - QuantIntType *qit, - Type *physical_type); + llvm::Value *atomic_add_quant_int(AtomicOpStmt *stmt, QuantIntType *qit); llvm::Value *quant_fixed_to_quant_int(QuantFixedType *qfxt, QuantIntType *qit, @@ -252,20 +248,11 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void store_quant_int(llvm::Value *bit_ptr, QuantIntType *qit, - Type *physical_type, - llvm::Value *value, - bool atomic); - - void store_quant_int(llvm::Value *byte_ptr, - llvm::Value *bit_offset, - QuantIntType *qit, - Type *physical_type, llvm::Value *value, bool atomic); void store_masked(llvm::Value *byte_ptr, uint64 mask, - Type *physical_type, llvm::Value *value, bool atomic); @@ -283,13 +270,11 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { SNode *digits_snode); llvm::Value *load_quant_int(llvm::Value *ptr, - QuantIntType *qit, - Type *physical_type); + QuantIntType *qit); llvm::Value *extract_quant_int(llvm::Value *physical_value, llvm::Value *bit_offset, - QuantIntType *qit, - Type *physical_type); + QuantIntType *qit); llvm::Value *reconstruct_quant_fixed(llvm::Value *digits, QuantFixedType *qfxt); @@ -297,7 +282,6 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *load_quant_float(llvm::Value *digits_bit_ptr, llvm::Value *exponent_bit_ptr, QuantFloatType *qflt, - Type *physical_type, bool shared_exponent); llvm::Value *reconstruct_quant_float(llvm::Value *input_digits, @@ -319,12 +303,11 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(IntegerOffsetStmt *stmt) override; - llvm::Value *create_bit_ptr_struct(llvm::Value *byte_ptr_base = nullptr, - llvm::Value *bit_offset = nullptr); + llvm::Value *create_bit_ptr(llvm::Value *byte_ptr, llvm::Value *bit_offset); - llvm::Value *offset_bit_ptr(llvm::Value *input_bit_ptr, int bit_offset_delta); + std::tuple load_bit_ptr(llvm::Value *bit_ptr); - std::tuple load_bit_pointer(llvm::Value *ptr); + llvm::Value *offset_bit_ptr(llvm::Value *bit_ptr, int bit_offset_delta); void visit(SNodeLookupStmt *stmt) override; diff --git a/taichi/codegen/codegen_llvm_quant.cpp b/taichi/codegen/codegen_llvm_quant.cpp index aad72f2fe42ad..9e07e10d85663 100644 --- a/taichi/codegen/codegen_llvm_quant.cpp +++ b/taichi/codegen/codegen_llvm_quant.cpp @@ -18,28 +18,22 @@ inline void update_mask(uint64 &mask, uint32 num_bits, uint32 offset) { } // namespace llvm::Value *CodeGenLLVM::atomic_add_quant_int(AtomicOpStmt *stmt, - QuantIntType *qit, - Type *physical_type) { - auto [byte_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->dest]); - return create_call( - fmt::format("atomic_add_partial_bits_b{}", data_type_bits(physical_type)), - {builder->CreateBitCast(byte_ptr, llvm_ptr_type(physical_type)), - bit_offset, tlctx->get_constant(qit->get_num_bits()), - builder->CreateIntCast(llvm_val[stmt->val], llvm_type(physical_type), - is_signed(stmt->val->ret_type))}); + QuantIntType *qit) { + auto [byte_ptr, bit_offset] = load_bit_ptr(llvm_val[stmt->dest]); + auto physical_type = byte_ptr->getType()->getPointerElementType(); + return create_call(fmt::format("atomic_add_partial_bits_b{}", physical_type->getIntegerBitWidth()), + {byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()), builder->CreateIntCast(llvm_val[stmt->val], physical_type, is_signed(stmt->val->ret_type))}); } llvm::Value *CodeGenLLVM::atomic_add_quant_fixed(AtomicOpStmt *stmt, - QuantFixedType *qfxt, - Type *physical_type) { - auto [byte_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->dest]); + QuantFixedType *qfxt) { + auto [byte_ptr, bit_offset] = load_bit_ptr(llvm_val[stmt->dest]); + auto physical_type = byte_ptr->getType()->getPointerElementType(); auto qit = qfxt->get_digits_type()->as(); auto val_store = quant_fixed_to_quant_int(qfxt, qit, llvm_val[stmt->val]); - val_store = builder->CreateSExt(val_store, llvm_type(physical_type)); - return create_call( - fmt::format("atomic_add_partial_bits_b{}", data_type_bits(physical_type)), - {builder->CreateBitCast(byte_ptr, llvm_ptr_type(physical_type)), - bit_offset, tlctx->get_constant(qit->get_num_bits()), val_store}); + val_store = builder->CreateSExt(val_store, physical_type); + return create_call(fmt::format("atomic_add_partial_bits_b{}", physical_type->getIntegerBitWidth()), + {byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()), val_store}); } llvm::Value *CodeGenLLVM::quant_fixed_to_quant_int(QuantFixedType *qfxt, @@ -69,48 +63,34 @@ llvm::Value *CodeGenLLVM::quant_fixed_to_quant_int(QuantFixedType *qfxt, void CodeGenLLVM::store_quant_int(llvm::Value *bit_ptr, QuantIntType *qit, - Type *physical_type, - llvm::Value *value, - bool atomic) { - auto [byte_ptr, bit_offset] = load_bit_pointer(bit_ptr); - store_quant_int(byte_ptr, bit_offset, qit, physical_type, value, atomic); -} - -void CodeGenLLVM::store_quant_int(llvm::Value *byte_ptr, - llvm::Value *bit_offset, - QuantIntType *qit, - Type *physical_type, llvm::Value *value, bool atomic) { + auto [byte_ptr, bit_offset] = load_bit_ptr(bit_ptr); + auto physical_type = byte_ptr->getType()->getPointerElementType(); // TODO(type): CUDA only supports atomicCAS on 32- and 64-bit integers. // Try to support 8/16-bit physical types. - create_call(fmt::format("{}set_partial_bits_b{}", atomic ? "atomic_" : "", - data_type_bits(physical_type)), - {builder->CreateBitCast(byte_ptr, llvm_ptr_type(physical_type)), - bit_offset, tlctx->get_constant(qit->get_num_bits()), - builder->CreateIntCast(value, llvm_type(physical_type), false)}); + create_call(fmt::format("{}set_partial_bits_b{}", atomic ? "atomic_" : "", physical_type->getIntegerBitWidth()), + {byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()), + builder->CreateIntCast(value, physical_type, false)}); } void CodeGenLLVM::store_masked(llvm::Value *byte_ptr, uint64 mask, - Type *physical_type, llvm::Value *value, bool atomic) { if (!mask) { // do not store anything return; } - uint64 full_mask = (~(uint64)0) >> (64 - data_type_bits(physical_type)); + auto physical_type = byte_ptr->getType()->getPointerElementType(); + uint64 full_mask = (~(uint64)0) >> (64 - physical_type->getIntegerBitWidth()); if ((!atomic || prog->config.quant_opt_atomic_demotion) && ((mask & full_mask) == full_mask)) { builder->CreateStore(value, byte_ptr); return; } - create_call(fmt::format("{}set_mask_b{}", atomic ? "atomic_" : "", - data_type_bits(physical_type)), - {builder->CreateBitCast(byte_ptr, llvm_ptr_type(physical_type)), - tlctx->get_constant(mask), - builder->CreateIntCast(value, llvm_type(physical_type), false)}); + create_call(fmt::format("{}set_mask_b{}", atomic ? "atomic_" : "", physical_type->getIntegerBitWidth()), + {byte_ptr, tlctx->get_constant(mask), builder->CreateIntCast(value, physical_type, false)}); } llvm::Value *CodeGenLLVM::get_exponent_offset(llvm::Value *exponent, @@ -296,8 +276,7 @@ void CodeGenLLVM::visit(BitStructStoreStmt *stmt) { } update_mask(mask, qit->get_num_bits(), ch->bit_offset); } - store_masked(llvm_val[stmt->ptr], mask, bit_struct_physical_type, - bit_struct_val, stmt->is_atomic); + store_masked(llvm_val[stmt->ptr], mask, bit_struct_val, stmt->is_atomic); } } @@ -410,8 +389,7 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents( update_mask(mask, num_digit_bits, digits_bit_offset); } } - store_masked(llvm_val[stmt->ptr], mask, bit_struct_physical_type, masked_val, - stmt->is_atomic); + store_masked(llvm_val[stmt->ptr], mask, masked_val, stmt->is_atomic); } llvm::Value *CodeGenLLVM::extract_exponent_from_f32(llvm::Value *f) { @@ -461,43 +439,39 @@ llvm::Value *CodeGenLLVM::extract_quant_float(llvm::Value *local_bit_struct, auto qflt = digits_snode->dt->as(); auto exponent_type = qflt->get_exponent_type()->as(); auto digits_type = qflt->get_digits_type()->as(); - auto physical_type = digits_snode->parent->physical_type; auto digits = extract_quant_int(local_bit_struct, tlctx->get_constant(digits_snode->bit_offset), - digits_type, physical_type); + digits_type); auto exponent = extract_quant_int( local_bit_struct, - tlctx->get_constant(digits_snode->exp_snode->bit_offset), exponent_type, - physical_type); + tlctx->get_constant(digits_snode->exp_snode->bit_offset), exponent_type); return reconstruct_quant_float(digits, exponent, qflt, digits_snode->owns_shared_exponent); } llvm::Value *CodeGenLLVM::load_quant_int(llvm::Value *ptr, - QuantIntType *qit, - Type *physical_type) { - auto [byte_ptr, bit_offset] = load_bit_pointer(ptr); - auto bit_level_container = builder->CreateLoad( - builder->CreateBitCast(byte_ptr, llvm_ptr_type(physical_type))); - return extract_quant_int(bit_level_container, bit_offset, qit, physical_type); + QuantIntType *qit) { + auto [byte_ptr, bit_offset] = load_bit_ptr(ptr); + auto physical_value = builder->CreateLoad(byte_ptr); + return extract_quant_int(physical_value, bit_offset, qit); } llvm::Value *CodeGenLLVM::extract_quant_int(llvm::Value *physical_value, llvm::Value *bit_offset, - QuantIntType *qit, - Type *physical_type) { + QuantIntType *qit) { + auto physical_type = physical_value->getType(); // bit shifting // first left shift `physical_type - (offset + num_bits)` // then right shift `physical_type - num_bits` auto bit_end = builder->CreateAdd(bit_offset, tlctx->get_constant(qit->get_num_bits())); auto left = builder->CreateSub( - tlctx->get_constant(data_type_bits(physical_type)), bit_end); + tlctx->get_constant(physical_type->getIntegerBitWidth()), bit_end); auto right = - builder->CreateSub(tlctx->get_constant(data_type_bits(physical_type)), + builder->CreateSub(tlctx->get_constant(physical_type->getIntegerBitWidth()), tlctx->get_constant(qit->get_num_bits())); - left = builder->CreateIntCast(left, physical_value->getType(), false); - right = builder->CreateIntCast(right, physical_value->getType(), false); + left = builder->CreateIntCast(left, physical_type, false); + right = builder->CreateIntCast(right, physical_type, false); auto step1 = builder->CreateShl(physical_value, left); llvm::Value *step2 = nullptr; @@ -528,14 +502,11 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits, llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_bit_ptr, llvm::Value *exponent_bit_ptr, QuantFloatType *qflt, - Type *physical_type, bool shared_exponent) { auto digits = load_quant_int(digits_bit_ptr, - qflt->get_digits_type()->as(), - physical_type); + qflt->get_digits_type()->as()); auto exponent_val = load_quant_int( - exponent_bit_ptr, qflt->get_exponent_type()->as(), - physical_type); + exponent_bit_ptr, qflt->get_exponent_type()->as()); return reconstruct_quant_float(digits, exponent_val, qflt, shared_exponent); } @@ -641,7 +612,6 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_float( llvm::Value *CodeGenLLVM::load_quant_fixed_or_quant_float(Stmt *ptr_stmt) { auto ptr = ptr_stmt->as(); auto load_type = ptr->ret_type->as()->get_pointee_type(); - auto physical_type = ptr->input_snode->physical_type; if (auto qflt = load_type->cast()) { TI_ASSERT(ptr->width() == 1); auto digits_bit_ptr = llvm_val[ptr]; @@ -651,13 +621,11 @@ llvm::Value *CodeGenLLVM::load_quant_fixed_or_quant_float(Stmt *ptr_stmt) { TI_ASSERT(digits_snode->parent == exponent_snode->parent); auto exponent_bit_ptr = offset_bit_ptr( digits_bit_ptr, exponent_snode->bit_offset - digits_snode->bit_offset); - return load_quant_float(digits_bit_ptr, exponent_bit_ptr, qflt, - physical_type, digits_snode->owns_shared_exponent); + return load_quant_float(digits_bit_ptr, exponent_bit_ptr, qflt, digits_snode->owns_shared_exponent); } else { auto qfxt = load_type->as(); auto digits = load_quant_int(llvm_val[ptr], - qfxt->get_digits_type()->as(), - physical_type); + qfxt->get_digits_type()->as()); return reconstruct_quant_fixed(digits, qfxt); } } From def861f6b87794161e5df11b4278ddc7cf81b71d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Jun 2022 09:59:35 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/codegen/codegen_llvm.cpp | 29 +++++++++++++++------- taichi/codegen/codegen_llvm.h | 3 +-- taichi/codegen/codegen_llvm_quant.cpp | 35 +++++++++++++++++---------- 3 files changed, 43 insertions(+), 24 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 8a073661910aa..90458f53b7d6c 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1474,32 +1474,43 @@ void CodeGenLLVM::visit(LinearizeStmt *stmt) { void CodeGenLLVM::visit(IntegerOffsetStmt *stmt){TI_NOT_IMPLEMENTED} -llvm::Value *CodeGenLLVM::create_bit_ptr(llvm::Value *byte_ptr, llvm::Value *bit_offset) { +llvm::Value *CodeGenLLVM::create_bit_ptr(llvm::Value *byte_ptr, + llvm::Value *bit_offset) { // 1. define the bit pointer struct (X=8/16/32/64) // struct bit_pointer_X { // iX* byte_ptr; // i32 bit_offset; // }; TI_ASSERT(bit_offset->getType()->isIntegerTy(32)); - auto struct_type = llvm::StructType::get(*llvm_context, {byte_ptr->getType(), bit_offset->getType()}); + auto struct_type = llvm::StructType::get( + *llvm_context, {byte_ptr->getType(), bit_offset->getType()}); // 2. allocate the bit pointer struct auto bit_ptr = create_entry_block_alloca(struct_type); // 3. store `byte_ptr` - builder->CreateStore(byte_ptr, builder->CreateGEP(bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)})); + builder->CreateStore( + byte_ptr, builder->CreateGEP( + bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)})); // 4. store `bit_offset - builder->CreateStore(bit_offset,builder->CreateGEP(bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(1)})); + builder->CreateStore(bit_offset, + builder->CreateGEP(bit_ptr, {tlctx->get_constant(0), + tlctx->get_constant(1)})); return bit_ptr; } -std::tuple CodeGenLLVM::load_bit_ptr(llvm::Value *bit_ptr) { - auto byte_ptr = builder->CreateLoad(builder->CreateGEP(bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)})); - auto bit_offset = builder->CreateLoad(builder->CreateGEP(bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(1)})); +std::tuple CodeGenLLVM::load_bit_ptr( + llvm::Value *bit_ptr) { + auto byte_ptr = builder->CreateLoad(builder->CreateGEP( + bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)})); + auto bit_offset = builder->CreateLoad(builder->CreateGEP( + bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(1)})); return std::make_tuple(byte_ptr, bit_offset); } -llvm::Value *CodeGenLLVM::offset_bit_ptr(llvm::Value *bit_ptr, int bit_offset_delta) { +llvm::Value *CodeGenLLVM::offset_bit_ptr(llvm::Value *bit_ptr, + int bit_offset_delta) { auto [byte_ptr, bit_offset] = load_bit_ptr(bit_ptr); - auto new_bit_offset = builder->CreateAdd(bit_offset, tlctx->get_constant(bit_offset_delta)); + auto new_bit_offset = + builder->CreateAdd(bit_offset, tlctx->get_constant(bit_offset_delta)); return create_bit_ptr(byte_ptr, new_bit_offset); } diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index 9b9dd88b7f41e..5409409fc9f35 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -269,8 +269,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *extract_quant_float(llvm::Value *local_bit_struct, SNode *digits_snode); - llvm::Value *load_quant_int(llvm::Value *ptr, - QuantIntType *qit); + llvm::Value *load_quant_int(llvm::Value *ptr, QuantIntType *qit); llvm::Value *extract_quant_int(llvm::Value *physical_value, llvm::Value *bit_offset, diff --git a/taichi/codegen/codegen_llvm_quant.cpp b/taichi/codegen/codegen_llvm_quant.cpp index 9e07e10d85663..4b221e4fb9f02 100644 --- a/taichi/codegen/codegen_llvm_quant.cpp +++ b/taichi/codegen/codegen_llvm_quant.cpp @@ -21,8 +21,12 @@ llvm::Value *CodeGenLLVM::atomic_add_quant_int(AtomicOpStmt *stmt, QuantIntType *qit) { auto [byte_ptr, bit_offset] = load_bit_ptr(llvm_val[stmt->dest]); auto physical_type = byte_ptr->getType()->getPointerElementType(); - return create_call(fmt::format("atomic_add_partial_bits_b{}", physical_type->getIntegerBitWidth()), - {byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()), builder->CreateIntCast(llvm_val[stmt->val], physical_type, is_signed(stmt->val->ret_type))}); + return create_call( + fmt::format("atomic_add_partial_bits_b{}", + physical_type->getIntegerBitWidth()), + {byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()), + builder->CreateIntCast(llvm_val[stmt->val], physical_type, + is_signed(stmt->val->ret_type))}); } llvm::Value *CodeGenLLVM::atomic_add_quant_fixed(AtomicOpStmt *stmt, @@ -32,8 +36,10 @@ llvm::Value *CodeGenLLVM::atomic_add_quant_fixed(AtomicOpStmt *stmt, auto qit = qfxt->get_digits_type()->as(); auto val_store = quant_fixed_to_quant_int(qfxt, qit, llvm_val[stmt->val]); val_store = builder->CreateSExt(val_store, physical_type); - return create_call(fmt::format("atomic_add_partial_bits_b{}", physical_type->getIntegerBitWidth()), - {byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()), val_store}); + return create_call(fmt::format("atomic_add_partial_bits_b{}", + physical_type->getIntegerBitWidth()), + {byte_ptr, bit_offset, + tlctx->get_constant(qit->get_num_bits()), val_store}); } llvm::Value *CodeGenLLVM::quant_fixed_to_quant_int(QuantFixedType *qfxt, @@ -69,7 +75,8 @@ void CodeGenLLVM::store_quant_int(llvm::Value *bit_ptr, auto physical_type = byte_ptr->getType()->getPointerElementType(); // TODO(type): CUDA only supports atomicCAS on 32- and 64-bit integers. // Try to support 8/16-bit physical types. - create_call(fmt::format("{}set_partial_bits_b{}", atomic ? "atomic_" : "", physical_type->getIntegerBitWidth()), + create_call(fmt::format("{}set_partial_bits_b{}", atomic ? "atomic_" : "", + physical_type->getIntegerBitWidth()), {byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()), builder->CreateIntCast(value, physical_type, false)}); } @@ -89,8 +96,10 @@ void CodeGenLLVM::store_masked(llvm::Value *byte_ptr, builder->CreateStore(value, byte_ptr); return; } - create_call(fmt::format("{}set_mask_b{}", atomic ? "atomic_" : "", physical_type->getIntegerBitWidth()), - {byte_ptr, tlctx->get_constant(mask), builder->CreateIntCast(value, physical_type, false)}); + create_call(fmt::format("{}set_mask_b{}", atomic ? "atomic_" : "", + physical_type->getIntegerBitWidth()), + {byte_ptr, tlctx->get_constant(mask), + builder->CreateIntCast(value, physical_type, false)}); } llvm::Value *CodeGenLLVM::get_exponent_offset(llvm::Value *exponent, @@ -449,8 +458,7 @@ llvm::Value *CodeGenLLVM::extract_quant_float(llvm::Value *local_bit_struct, digits_snode->owns_shared_exponent); } -llvm::Value *CodeGenLLVM::load_quant_int(llvm::Value *ptr, - QuantIntType *qit) { +llvm::Value *CodeGenLLVM::load_quant_int(llvm::Value *ptr, QuantIntType *qit) { auto [byte_ptr, bit_offset] = load_bit_ptr(ptr); auto physical_value = builder->CreateLoad(byte_ptr); return extract_quant_int(physical_value, bit_offset, qit); @@ -467,9 +475,9 @@ llvm::Value *CodeGenLLVM::extract_quant_int(llvm::Value *physical_value, builder->CreateAdd(bit_offset, tlctx->get_constant(qit->get_num_bits())); auto left = builder->CreateSub( tlctx->get_constant(physical_type->getIntegerBitWidth()), bit_end); - auto right = - builder->CreateSub(tlctx->get_constant(physical_type->getIntegerBitWidth()), - tlctx->get_constant(qit->get_num_bits())); + auto right = builder->CreateSub( + tlctx->get_constant(physical_type->getIntegerBitWidth()), + tlctx->get_constant(qit->get_num_bits())); left = builder->CreateIntCast(left, physical_type, false); right = builder->CreateIntCast(right, physical_type, false); auto step1 = builder->CreateShl(physical_value, left); @@ -621,7 +629,8 @@ llvm::Value *CodeGenLLVM::load_quant_fixed_or_quant_float(Stmt *ptr_stmt) { TI_ASSERT(digits_snode->parent == exponent_snode->parent); auto exponent_bit_ptr = offset_bit_ptr( digits_bit_ptr, exponent_snode->bit_offset - digits_snode->bit_offset); - return load_quant_float(digits_bit_ptr, exponent_bit_ptr, qflt, digits_snode->owns_shared_exponent); + return load_quant_float(digits_bit_ptr, exponent_bit_ptr, qflt, + digits_snode->owns_shared_exponent); } else { auto qfxt = load_type->as(); auto digits = load_quant_int(llvm_val[ptr],