diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 419dcc6d9368b..869ce72674b6a 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -406,8 +406,8 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) { } } } else if (!is_real(from) && !is_real(to)) { - llvm_val[stmt] = builder->CreateIntCast(llvm_val[stmt->operand], - llvm_type(to), is_signed(from)); + llvm_val[stmt] = builder->CreateIntCast( + llvm_val[stmt->operand], tlctx->get_data_type(to), is_signed(from)); } } else if (stmt->op_type == UnaryOpType::cast_bits) { TI_ASSERT(data_type_size(stmt->ret_type) == @@ -621,7 +621,8 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { } else { TI_NOT_IMPLEMENTED } - llvm_val[stmt] = builder->CreateSExt(cmp, llvm_type(PrimitiveType::i32)); + llvm_val[stmt] = + builder->CreateSExt(cmp, tlctx->get_data_type(PrimitiveType::i32)); } else { // This branch contains atan2 and pow which use runtime.cpp function for // **real** type. We don't have f16 support there so promoting to f32 is @@ -681,48 +682,11 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { } } -llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { - if (dt->is_primitive(PrimitiveTypeID::i8) || - dt->is_primitive(PrimitiveTypeID::u8)) { - return llvm::Type::getInt8Ty(*llvm_context); - } else if (dt->is_primitive(PrimitiveTypeID::i16) || - dt->is_primitive(PrimitiveTypeID::u16)) { - return llvm::Type::getInt16Ty(*llvm_context); - } else if (dt->is_primitive(PrimitiveTypeID::i32) || - dt->is_primitive(PrimitiveTypeID::u32)) { - return llvm::Type::getInt32Ty(*llvm_context); - } else if (dt->is_primitive(PrimitiveTypeID::i64) || - dt->is_primitive(PrimitiveTypeID::u64)) { - return llvm::Type::getInt64Ty(*llvm_context); - } else if (dt->is_primitive(PrimitiveTypeID::u1)) { - return llvm::Type::getInt1Ty(*llvm_context); - } else if (dt->is_primitive(PrimitiveTypeID::f32)) { - return llvm::Type::getFloatTy(*llvm_context); - } else if (dt->is_primitive(PrimitiveTypeID::f64)) { - return llvm::Type::getDoubleTy(*llvm_context); - } else if (dt->is_primitive(PrimitiveTypeID::f16)) { - return llvm::Type::getHalfTy(*llvm_context); - } else if (dt->is()) { - TI_ASSERT_INFO(kernel->program->config.real_matrix, - "Real matrix not enabled but got TensorType"); - auto tensor_type = dt->cast(); - auto element_type = llvm_type(tensor_type->get_element_type()); - return llvm::VectorType::get(element_type, tensor_type->get_num_elements(), - /*scalable=*/false); - } else { - TI_NOT_IMPLEMENTED; - } - return nullptr; -} - -llvm::Type *TaskCodeGenLLVM::llvm_ptr_type(DataType dt) { - return llvm::PointerType::get(llvm_type(dt), 0); -} - void TaskCodeGenLLVM::visit(TernaryOpStmt *stmt) { TI_ASSERT(stmt->op_type == TernaryOpType::select); llvm_val[stmt] = builder->CreateSelect( - builder->CreateTrunc(llvm_val[stmt->op1], llvm_type(PrimitiveType::u1)), + builder->CreateTrunc(llvm_val[stmt->op1], + tlctx->get_data_type(PrimitiveType::u1)), llvm_val[stmt->op2], llvm_val[stmt->op3]); } @@ -1030,7 +994,7 @@ void TaskCodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) { BasicBlock::Create(*llvm_context, "for_loop_test", func); #ifdef TI_LLVM_15 - auto loop_var_ty = llvm_type(PrimitiveType::i32); + auto loop_var_ty = tlctx->get_data_type(PrimitiveType::i32); #endif auto loop_var = create_entry_block_alloca(PrimitiveType::i32); @@ -1195,7 +1159,7 @@ void TaskCodeGenLLVM::visit(LocalLoadStmt *stmt) { if (auto *alloc = llvm::dyn_cast(val)) ptr_ty = alloc->getAllocatedType(); if (!ptr_ty && stmt->src->element_type().is_pointer()) { - ptr_ty = llvm_type(stmt->src->element_type().ptr_removed()); + ptr_ty = tlctx->get_data_type(stmt->src->element_type().ptr_removed()); } TI_ASSERT(ptr_ty); llvm_val[stmt] = builder->CreateLoad(ptr_ty, llvm_val[stmt->src]); @@ -1298,12 +1262,14 @@ llvm::Value *TaskCodeGenLLVM::quant_type_atomic(AtomicOpStmt *stmt) { if (auto qit = dst_type->cast()) { return atomic_add_quant_int( llvm_val[stmt->dest], - llvm_type(stmt->dest->as()->input_snode->physical_type), qit, - llvm_val[stmt->val], is_signed(stmt->val->ret_type)); + tlctx->get_data_type( + stmt->dest->as()->input_snode->physical_type), + qit, llvm_val[stmt->val], is_signed(stmt->val->ret_type)); } else if (auto qfxt = dst_type->cast()) { return atomic_add_quant_fixed( llvm_val[stmt->dest], - llvm_type(stmt->dest->as()->input_snode->physical_type), + tlctx->get_data_type( + stmt->dest->as()->input_snode->physical_type), qfxt, llvm_val[stmt->val]); } else { return nullptr; @@ -1464,11 +1430,13 @@ void TaskCodeGenLLVM::visit(GlobalStoreStmt *stmt) { pointee_type->to_string()); } if (auto qit = pointee_type->cast()) { - store_quant_int(llvm_val[stmt->dest], llvm_type(snode->physical_type), - qit, llvm_val[stmt->val], true); + store_quant_int(llvm_val[stmt->dest], + tlctx->get_data_type(snode->physical_type), qit, + llvm_val[stmt->val], true); } else if (auto qfxt = pointee_type->cast()) { - store_quant_fixed(llvm_val[stmt->dest], llvm_type(snode->physical_type), - qfxt, llvm_val[stmt->val], true); + store_quant_fixed(llvm_val[stmt->dest], + tlctx->get_data_type(snode->physical_type), qfxt, + llvm_val[stmt->val], true); } else { TI_NOT_IMPLEMENTED; } @@ -1489,7 +1457,8 @@ void TaskCodeGenLLVM::create_global_load(GlobalLoadStmt *stmt, if (ptr_type->is_bit_pointer()) { auto val_type = ptr_type->get_pointee_type(); auto get_ch = stmt->src->as(); - auto physical_type = llvm_type(get_ch->input_snode->physical_type); + auto physical_type = + tlctx->get_data_type(get_ch->input_snode->physical_type); auto [byte_ptr, bit_offset] = load_bit_ptr(ptr); auto physical_value = should_cache_as_read_only ? create_intrinsic_load(byte_ptr, physical_type) @@ -1510,7 +1479,8 @@ void TaskCodeGenLLVM::create_global_load(GlobalLoadStmt *stmt, } else { // Byte pointer case. if (should_cache_as_read_only) { - llvm_val[stmt] = create_intrinsic_load(ptr, llvm_type(stmt->ret_type)); + llvm_val[stmt] = + create_intrinsic_load(ptr, tlctx->get_data_type(stmt->ret_type)); } else { llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), ptr); @@ -1892,7 +1862,7 @@ std::tuple TaskCodeGenLLVM::get_range_for_bounds( begin_stmt->accept(this); begin = builder->CreateLoad( #ifdef TI_LLVM_15 - llvm_type(PrimitiveType::i32), + tlctx->get_data_type(PrimitiveType::i32), #endif llvm_val[begin_stmt.get()]); } @@ -1904,7 +1874,7 @@ std::tuple TaskCodeGenLLVM::get_range_for_bounds( end_stmt->accept(this); end = builder->CreateLoad( #ifdef TI_LLVM_15 - llvm_type(PrimitiveType::i32), + tlctx->get_data_type(PrimitiveType::i32), #endif llvm_val[end_stmt.get()]); } diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h index dab60a12c8830..9d952f29f666b 100644 --- a/taichi/codegen/llvm/codegen_llvm.h +++ b/taichi/codegen/llvm/codegen_llvm.h @@ -152,10 +152,6 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { static std::string get_runtime_snode_name(SNode *snode); - llvm::Type *llvm_type(DataType dt); - - llvm::Type *llvm_ptr_type(DataType dt); - void visit(Block *stmt_list) override; void visit(AllocaStmt *stmt) override; diff --git a/taichi/codegen/llvm/codegen_llvm_quant.cpp b/taichi/codegen/llvm/codegen_llvm_quant.cpp index 15ae796137ef9..6bd34756a5930 100644 --- a/taichi/codegen/llvm/codegen_llvm_quant.cpp +++ b/taichi/codegen/llvm/codegen_llvm_quant.cpp @@ -47,8 +47,9 @@ llvm::Value *TaskCodeGenLLVM::to_quant_fixed(llvm::Value *real, // Compute int(real * (1.0 / scale) + 0.5) auto compute_type = qfxt->get_compute_type(); auto s = builder->CreateFPCast(tlctx->get_constant(1.0 / qfxt->get_scale()), - llvm_type(compute_type)); - auto input_real = builder->CreateFPCast(real, llvm_type(compute_type)); + tlctx->get_data_type(compute_type)); + auto input_real = + builder->CreateFPCast(real, tlctx->get_data_type(compute_type)); auto scaled = builder->CreateFMul(input_real, s); // Add/minus the 0.5 offset for rounding @@ -58,9 +59,11 @@ llvm::Value *TaskCodeGenLLVM::to_quant_fixed(llvm::Value *real, auto qit = qfxt->get_digits_type()->as(); if (qit->get_is_signed()) { - return builder->CreateFPToSI(scaled, llvm_type(qit->get_compute_type())); + return builder->CreateFPToSI(scaled, + tlctx->get_data_type(qit->get_compute_type())); } else { - return builder->CreateFPToUI(scaled, llvm_type(qit->get_compute_type())); + return builder->CreateFPToUI(scaled, + tlctx->get_data_type(qit->get_compute_type())); } } @@ -143,7 +146,7 @@ llvm::Value *TaskCodeGenLLVM::quant_int_or_quant_fixed_to_bits( void TaskCodeGenLLVM::visit(BitStructStoreStmt *stmt) { auto bit_struct = stmt->get_bit_struct(); - auto physical_type = llvm_type(bit_struct->get_physical_type()); + auto physical_type = tlctx->get_data_type(bit_struct->get_physical_type()); int num_non_exponent_children = 0; for (int i = 0; i < bit_struct->get_num_members(); i++) { @@ -288,7 +291,7 @@ void TaskCodeGenLLVM::store_quant_floats_with_shared_exponents( BitStructStoreStmt *stmt) { // handle each exponent separately auto bit_struct = stmt->get_bit_struct(); - auto physical_type = llvm_type(bit_struct->get_physical_type()); + auto physical_type = tlctx->get_data_type(bit_struct->get_physical_type()); auto physical_value = builder->CreateLoad(physical_type, llvm_val[stmt->ptr]); // fuse all stores into a masked store llvm::Value *masked_val = nullptr; @@ -469,7 +472,8 @@ llvm::Value *TaskCodeGenLLVM::extract_quant_int(llvm::Value *physical_value, else step2 = builder->CreateLShr(step1, right); - return builder->CreateIntCast(step2, llvm_type(qit->get_compute_type()), + return builder->CreateIntCast(step2, + tlctx->get_data_type(qit->get_compute_type()), qit->get_is_signed()); } @@ -479,12 +483,12 @@ llvm::Value *TaskCodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits, llvm::Value *cast = nullptr; auto compute_type = qfxt->get_compute_type()->as(); if (qfxt->get_is_signed()) { - cast = builder->CreateSIToFP(digits, llvm_type(compute_type)); + cast = builder->CreateSIToFP(digits, tlctx->get_data_type(compute_type)); } else { - cast = builder->CreateUIToFP(digits, llvm_type(compute_type)); + cast = builder->CreateUIToFP(digits, tlctx->get_data_type(compute_type)); } llvm::Value *s = tlctx->get_constant(qfxt->get_scale()); - s = builder->CreateFPCast(s, llvm_type(compute_type)); + s = builder->CreateFPCast(s, tlctx->get_data_type(compute_type)); return builder->CreateFMul(cast, s); } diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index 8e99f32503b58..f1f127540b80f 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -120,38 +120,36 @@ TaichiLLVMContext::~TaichiLLVMContext() { llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { auto ctx = get_this_thread_context(); - if (dt->is_primitive(PrimitiveTypeID::i32)) { - return llvm::Type::getInt32Ty(*ctx); - } else if (dt->is_primitive(PrimitiveTypeID::i8)) { + if (dt->is_primitive(PrimitiveTypeID::i8) || + dt->is_primitive(PrimitiveTypeID::u8)) { return llvm::Type::getInt8Ty(*ctx); - } else if (dt->is_primitive(PrimitiveTypeID::i16)) { + } else if (dt->is_primitive(PrimitiveTypeID::i16) || + dt->is_primitive(PrimitiveTypeID::u16)) { return llvm::Type::getInt16Ty(*ctx); - } else if (dt->is_primitive(PrimitiveTypeID::i64)) { + } else if (dt->is_primitive(PrimitiveTypeID::i32) || + dt->is_primitive(PrimitiveTypeID::u32)) { + return llvm::Type::getInt32Ty(*ctx); + } else if (dt->is_primitive(PrimitiveTypeID::i64) || + dt->is_primitive(PrimitiveTypeID::u64)) { return llvm::Type::getInt64Ty(*ctx); + } else if (dt->is_primitive(PrimitiveTypeID::u1)) { + return llvm::Type::getInt1Ty(*ctx); } else if (dt->is_primitive(PrimitiveTypeID::f32)) { return llvm::Type::getFloatTy(*ctx); } else if (dt->is_primitive(PrimitiveTypeID::f64)) { return llvm::Type::getDoubleTy(*ctx); - } else if (dt->is_primitive(PrimitiveTypeID::u8)) { - return llvm::Type::getInt8Ty(*ctx); - } else if (dt->is_primitive(PrimitiveTypeID::u16)) { - return llvm::Type::getInt16Ty(*ctx); - } else if (dt->is_primitive(PrimitiveTypeID::u32)) { - return llvm::Type::getInt32Ty(*ctx); - } else if (dt->is_primitive(PrimitiveTypeID::u64)) { - return llvm::Type::getInt64Ty(*ctx); } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*ctx); } else if (dt->is()) { TI_ASSERT_INFO(config_->real_matrix, "Real matrix not enabled but got TensorType"); - auto vectorty = dt->as(); - auto dtype = this->get_data_type(vectorty->get_element_type()); - return llvm::VectorType::get(dtype, vectorty->get_num_elements(), + auto tensor_type = dt->cast(); + auto element_type = get_data_type(tensor_type->get_element_type()); + return llvm::VectorType::get(element_type, tensor_type->get_num_elements(), /*scalable=*/false); } else { TI_INFO(data_type_name(dt)); - TI_NOT_IMPLEMENTED + TI_NOT_IMPLEMENTED; } }