Skip to content

Commit

Permalink
[refactor] [llvm] Unify llvm_type() and get_data_type() (#5927)
Browse files Browse the repository at this point in the history
* [refactor] [llvm] Unify llvm_type() and get_data_type()

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Aug 31, 2022
1 parent ef56801 commit ffee072
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 86 deletions.
80 changes: 25 additions & 55 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
}
}
} else if (!is_real(from) && !is_real(to)) {
llvm_val[stmt] = builder->CreateIntCast(llvm_val[stmt->operand],
llvm_type(to), is_signed(from));
llvm_val[stmt] = builder->CreateIntCast(
llvm_val[stmt->operand], tlctx->get_data_type(to), is_signed(from));
}
} else if (stmt->op_type == UnaryOpType::cast_bits) {
TI_ASSERT(data_type_size(stmt->ret_type) ==
Expand Down Expand Up @@ -621,7 +621,8 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
} else {
TI_NOT_IMPLEMENTED
}
llvm_val[stmt] = builder->CreateSExt(cmp, llvm_type(PrimitiveType::i32));
llvm_val[stmt] =
builder->CreateSExt(cmp, tlctx->get_data_type(PrimitiveType::i32));
} else {
// This branch contains atan2 and pow which use runtime.cpp function for
// **real** type. We don't have f16 support there so promoting to f32 is
Expand Down Expand Up @@ -681,48 +682,11 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
}
}

llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) {
if (dt->is_primitive(PrimitiveTypeID::i8) ||
dt->is_primitive(PrimitiveTypeID::u8)) {
return llvm::Type::getInt8Ty(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::i16) ||
dt->is_primitive(PrimitiveTypeID::u16)) {
return llvm::Type::getInt16Ty(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::u32)) {
return llvm::Type::getInt32Ty(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::i64) ||
dt->is_primitive(PrimitiveTypeID::u64)) {
return llvm::Type::getInt64Ty(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return llvm::Type::getInt1Ty(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::f32)) {
return llvm::Type::getFloatTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
return llvm::Type::getDoubleTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::f16)) {
return llvm::Type::getHalfTy(*llvm_context);
} else if (dt->is<TensorType>()) {
TI_ASSERT_INFO(kernel->program->config.real_matrix,
"Real matrix not enabled but got TensorType");
auto tensor_type = dt->cast<TensorType>();
auto element_type = llvm_type(tensor_type->get_element_type());
return llvm::VectorType::get(element_type, tensor_type->get_num_elements(),
/*scalable=*/false);
} else {
TI_NOT_IMPLEMENTED;
}
return nullptr;
}

llvm::Type *TaskCodeGenLLVM::llvm_ptr_type(DataType dt) {
return llvm::PointerType::get(llvm_type(dt), 0);
}

void TaskCodeGenLLVM::visit(TernaryOpStmt *stmt) {
TI_ASSERT(stmt->op_type == TernaryOpType::select);
llvm_val[stmt] = builder->CreateSelect(
builder->CreateTrunc(llvm_val[stmt->op1], llvm_type(PrimitiveType::u1)),
builder->CreateTrunc(llvm_val[stmt->op1],
tlctx->get_data_type(PrimitiveType::u1)),
llvm_val[stmt->op2], llvm_val[stmt->op3]);
}

Expand Down Expand Up @@ -1030,7 +994,7 @@ void TaskCodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
BasicBlock::Create(*llvm_context, "for_loop_test", func);

#ifdef TI_LLVM_15
auto loop_var_ty = llvm_type(PrimitiveType::i32);
auto loop_var_ty = tlctx->get_data_type(PrimitiveType::i32);
#endif

auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
Expand Down Expand Up @@ -1195,7 +1159,7 @@ void TaskCodeGenLLVM::visit(LocalLoadStmt *stmt) {
if (auto *alloc = llvm::dyn_cast<llvm::AllocaInst>(val))
ptr_ty = alloc->getAllocatedType();
if (!ptr_ty && stmt->src->element_type().is_pointer()) {
ptr_ty = llvm_type(stmt->src->element_type().ptr_removed());
ptr_ty = tlctx->get_data_type(stmt->src->element_type().ptr_removed());
}
TI_ASSERT(ptr_ty);
llvm_val[stmt] = builder->CreateLoad(ptr_ty, llvm_val[stmt->src]);
Expand Down Expand Up @@ -1298,12 +1262,14 @@ llvm::Value *TaskCodeGenLLVM::quant_type_atomic(AtomicOpStmt *stmt) {
if (auto qit = dst_type->cast<QuantIntType>()) {
return atomic_add_quant_int(
llvm_val[stmt->dest],
llvm_type(stmt->dest->as<GetChStmt>()->input_snode->physical_type), qit,
llvm_val[stmt->val], is_signed(stmt->val->ret_type));
tlctx->get_data_type(
stmt->dest->as<GetChStmt>()->input_snode->physical_type),
qit, llvm_val[stmt->val], is_signed(stmt->val->ret_type));
} else if (auto qfxt = dst_type->cast<QuantFixedType>()) {
return atomic_add_quant_fixed(
llvm_val[stmt->dest],
llvm_type(stmt->dest->as<GetChStmt>()->input_snode->physical_type),
tlctx->get_data_type(
stmt->dest->as<GetChStmt>()->input_snode->physical_type),
qfxt, llvm_val[stmt->val]);
} else {
return nullptr;
Expand Down Expand Up @@ -1464,11 +1430,13 @@ void TaskCodeGenLLVM::visit(GlobalStoreStmt *stmt) {
pointee_type->to_string());
}
if (auto qit = pointee_type->cast<QuantIntType>()) {
store_quant_int(llvm_val[stmt->dest], llvm_type(snode->physical_type),
qit, llvm_val[stmt->val], true);
store_quant_int(llvm_val[stmt->dest],
tlctx->get_data_type(snode->physical_type), qit,
llvm_val[stmt->val], true);
} else if (auto qfxt = pointee_type->cast<QuantFixedType>()) {
store_quant_fixed(llvm_val[stmt->dest], llvm_type(snode->physical_type),
qfxt, llvm_val[stmt->val], true);
store_quant_fixed(llvm_val[stmt->dest],
tlctx->get_data_type(snode->physical_type), qfxt,
llvm_val[stmt->val], true);
} else {
TI_NOT_IMPLEMENTED;
}
Expand All @@ -1489,7 +1457,8 @@ void TaskCodeGenLLVM::create_global_load(GlobalLoadStmt *stmt,
if (ptr_type->is_bit_pointer()) {
auto val_type = ptr_type->get_pointee_type();
auto get_ch = stmt->src->as<GetChStmt>();
auto physical_type = llvm_type(get_ch->input_snode->physical_type);
auto physical_type =
tlctx->get_data_type(get_ch->input_snode->physical_type);
auto [byte_ptr, bit_offset] = load_bit_ptr(ptr);
auto physical_value = should_cache_as_read_only
? create_intrinsic_load(byte_ptr, physical_type)
Expand All @@ -1510,7 +1479,8 @@ void TaskCodeGenLLVM::create_global_load(GlobalLoadStmt *stmt,
} else {
// Byte pointer case.
if (should_cache_as_read_only) {
llvm_val[stmt] = create_intrinsic_load(ptr, llvm_type(stmt->ret_type));
llvm_val[stmt] =
create_intrinsic_load(ptr, tlctx->get_data_type(stmt->ret_type));
} else {
llvm_val[stmt] =
builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), ptr);
Expand Down Expand Up @@ -1892,7 +1862,7 @@ std::tuple<llvm::Value *, llvm::Value *> TaskCodeGenLLVM::get_range_for_bounds(
begin_stmt->accept(this);
begin = builder->CreateLoad(
#ifdef TI_LLVM_15
llvm_type(PrimitiveType::i32),
tlctx->get_data_type(PrimitiveType::i32),
#endif
llvm_val[begin_stmt.get()]);
}
Expand All @@ -1904,7 +1874,7 @@ std::tuple<llvm::Value *, llvm::Value *> TaskCodeGenLLVM::get_range_for_bounds(
end_stmt->accept(this);
end = builder->CreateLoad(
#ifdef TI_LLVM_15
llvm_type(PrimitiveType::i32),
tlctx->get_data_type(PrimitiveType::i32),
#endif
llvm_val[end_stmt.get()]);
}
Expand Down
4 changes: 0 additions & 4 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,6 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

static std::string get_runtime_snode_name(SNode *snode);

llvm::Type *llvm_type(DataType dt);

llvm::Type *llvm_ptr_type(DataType dt);

void visit(Block *stmt_list) override;

void visit(AllocaStmt *stmt) override;
Expand Down
24 changes: 14 additions & 10 deletions taichi/codegen/llvm/codegen_llvm_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ llvm::Value *TaskCodeGenLLVM::to_quant_fixed(llvm::Value *real,
// Compute int(real * (1.0 / scale) + 0.5)
auto compute_type = qfxt->get_compute_type();
auto s = builder->CreateFPCast(tlctx->get_constant(1.0 / qfxt->get_scale()),
llvm_type(compute_type));
auto input_real = builder->CreateFPCast(real, llvm_type(compute_type));
tlctx->get_data_type(compute_type));
auto input_real =
builder->CreateFPCast(real, tlctx->get_data_type(compute_type));
auto scaled = builder->CreateFMul(input_real, s);

// Add/minus the 0.5 offset for rounding
Expand All @@ -58,9 +59,11 @@ llvm::Value *TaskCodeGenLLVM::to_quant_fixed(llvm::Value *real,

auto qit = qfxt->get_digits_type()->as<QuantIntType>();
if (qit->get_is_signed()) {
return builder->CreateFPToSI(scaled, llvm_type(qit->get_compute_type()));
return builder->CreateFPToSI(scaled,
tlctx->get_data_type(qit->get_compute_type()));
} else {
return builder->CreateFPToUI(scaled, llvm_type(qit->get_compute_type()));
return builder->CreateFPToUI(scaled,
tlctx->get_data_type(qit->get_compute_type()));
}
}

Expand Down Expand Up @@ -143,7 +146,7 @@ llvm::Value *TaskCodeGenLLVM::quant_int_or_quant_fixed_to_bits(

void TaskCodeGenLLVM::visit(BitStructStoreStmt *stmt) {
auto bit_struct = stmt->get_bit_struct();
auto physical_type = llvm_type(bit_struct->get_physical_type());
auto physical_type = tlctx->get_data_type(bit_struct->get_physical_type());

int num_non_exponent_children = 0;
for (int i = 0; i < bit_struct->get_num_members(); i++) {
Expand Down Expand Up @@ -288,7 +291,7 @@ void TaskCodeGenLLVM::store_quant_floats_with_shared_exponents(
BitStructStoreStmt *stmt) {
// handle each exponent separately
auto bit_struct = stmt->get_bit_struct();
auto physical_type = llvm_type(bit_struct->get_physical_type());
auto physical_type = tlctx->get_data_type(bit_struct->get_physical_type());
auto physical_value = builder->CreateLoad(physical_type, llvm_val[stmt->ptr]);
// fuse all stores into a masked store
llvm::Value *masked_val = nullptr;
Expand Down Expand Up @@ -469,7 +472,8 @@ llvm::Value *TaskCodeGenLLVM::extract_quant_int(llvm::Value *physical_value,
else
step2 = builder->CreateLShr(step1, right);

return builder->CreateIntCast(step2, llvm_type(qit->get_compute_type()),
return builder->CreateIntCast(step2,
tlctx->get_data_type(qit->get_compute_type()),
qit->get_is_signed());
}

Expand All @@ -479,12 +483,12 @@ llvm::Value *TaskCodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits,
llvm::Value *cast = nullptr;
auto compute_type = qfxt->get_compute_type()->as<PrimitiveType>();
if (qfxt->get_is_signed()) {
cast = builder->CreateSIToFP(digits, llvm_type(compute_type));
cast = builder->CreateSIToFP(digits, tlctx->get_data_type(compute_type));
} else {
cast = builder->CreateUIToFP(digits, llvm_type(compute_type));
cast = builder->CreateUIToFP(digits, tlctx->get_data_type(compute_type));
}
llvm::Value *s = tlctx->get_constant(qfxt->get_scale());
s = builder->CreateFPCast(s, llvm_type(compute_type));
s = builder->CreateFPCast(s, tlctx->get_data_type(compute_type));
return builder->CreateFMul(cast, s);
}

Expand Down
32 changes: 15 additions & 17 deletions taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,38 +120,36 @@ TaichiLLVMContext::~TaichiLLVMContext() {

llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) {
auto ctx = get_this_thread_context();
if (dt->is_primitive(PrimitiveTypeID::i32)) {
return llvm::Type::getInt32Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::i8)) {
if (dt->is_primitive(PrimitiveTypeID::i8) ||
dt->is_primitive(PrimitiveTypeID::u8)) {
return llvm::Type::getInt8Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
} else if (dt->is_primitive(PrimitiveTypeID::i16) ||
dt->is_primitive(PrimitiveTypeID::u16)) {
return llvm::Type::getInt16Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::i64)) {
} else if (dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::u32)) {
return llvm::Type::getInt32Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::i64) ||
dt->is_primitive(PrimitiveTypeID::u64)) {
return llvm::Type::getInt64Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return llvm::Type::getInt1Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::f32)) {
return llvm::Type::getFloatTy(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
return llvm::Type::getDoubleTy(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return llvm::Type::getInt8Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
return llvm::Type::getInt16Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::u32)) {
return llvm::Type::getInt32Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::u64)) {
return llvm::Type::getInt64Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::f16)) {
return llvm::Type::getHalfTy(*ctx);
} else if (dt->is<TensorType>()) {
TI_ASSERT_INFO(config_->real_matrix,
"Real matrix not enabled but got TensorType");
auto vectorty = dt->as<TensorType>();
auto dtype = this->get_data_type(vectorty->get_element_type());
return llvm::VectorType::get(dtype, vectorty->get_num_elements(),
auto tensor_type = dt->cast<TensorType>();
auto element_type = get_data_type(tensor_type->get_element_type());
return llvm::VectorType::get(element_type, tensor_type->get_num_elements(),
/*scalable=*/false);
} else {
TI_INFO(data_type_name(dt));
TI_NOT_IMPLEMENTED
TI_NOT_IMPLEMENTED;
}
}

Expand Down

0 comments on commit ffee072

Please sign in to comment.