Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] [llvm] Unify llvm_type() and get_data_type() #5927

Merged
merged 2 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 25 additions & 55 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
}
}
} else if (!is_real(from) && !is_real(to)) {
llvm_val[stmt] = builder->CreateIntCast(llvm_val[stmt->operand],
llvm_type(to), is_signed(from));
llvm_val[stmt] = builder->CreateIntCast(
llvm_val[stmt->operand], tlctx->get_data_type(to), is_signed(from));
}
} else if (stmt->op_type == UnaryOpType::cast_bits) {
TI_ASSERT(data_type_size(stmt->ret_type) ==
Expand Down Expand Up @@ -621,7 +621,8 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
} else {
TI_NOT_IMPLEMENTED
}
llvm_val[stmt] = builder->CreateSExt(cmp, llvm_type(PrimitiveType::i32));
llvm_val[stmt] =
builder->CreateSExt(cmp, tlctx->get_data_type(PrimitiveType::i32));
} else {
// This branch contains atan2 and pow which use runtime.cpp function for
// **real** type. We don't have f16 support there so promoting to f32 is
Expand Down Expand Up @@ -681,48 +682,11 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
}
}

llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) {
if (dt->is_primitive(PrimitiveTypeID::i8) ||
dt->is_primitive(PrimitiveTypeID::u8)) {
return llvm::Type::getInt8Ty(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::i16) ||
dt->is_primitive(PrimitiveTypeID::u16)) {
return llvm::Type::getInt16Ty(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::u32)) {
return llvm::Type::getInt32Ty(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::i64) ||
dt->is_primitive(PrimitiveTypeID::u64)) {
return llvm::Type::getInt64Ty(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return llvm::Type::getInt1Ty(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::f32)) {
return llvm::Type::getFloatTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
return llvm::Type::getDoubleTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::f16)) {
return llvm::Type::getHalfTy(*llvm_context);
} else if (dt->is<TensorType>()) {
TI_ASSERT_INFO(kernel->program->config.real_matrix,
"Real matrix not enabled but got TensorType");
auto tensor_type = dt->cast<TensorType>();
auto element_type = llvm_type(tensor_type->get_element_type());
return llvm::VectorType::get(element_type, tensor_type->get_num_elements(),
/*scalable=*/false);
} else {
TI_NOT_IMPLEMENTED;
}
return nullptr;
}

llvm::Type *TaskCodeGenLLVM::llvm_ptr_type(DataType dt) {
return llvm::PointerType::get(llvm_type(dt), 0);
}

void TaskCodeGenLLVM::visit(TernaryOpStmt *stmt) {
TI_ASSERT(stmt->op_type == TernaryOpType::select);
llvm_val[stmt] = builder->CreateSelect(
builder->CreateTrunc(llvm_val[stmt->op1], llvm_type(PrimitiveType::u1)),
builder->CreateTrunc(llvm_val[stmt->op1],
tlctx->get_data_type(PrimitiveType::u1)),
llvm_val[stmt->op2], llvm_val[stmt->op3]);
}

Expand Down Expand Up @@ -1030,7 +994,7 @@ void TaskCodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
BasicBlock::Create(*llvm_context, "for_loop_test", func);

#ifdef TI_LLVM_15
auto loop_var_ty = llvm_type(PrimitiveType::i32);
auto loop_var_ty = tlctx->get_data_type(PrimitiveType::i32);
#endif

auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
Expand Down Expand Up @@ -1195,7 +1159,7 @@ void TaskCodeGenLLVM::visit(LocalLoadStmt *stmt) {
if (auto *alloc = llvm::dyn_cast<llvm::AllocaInst>(val))
ptr_ty = alloc->getAllocatedType();
if (!ptr_ty && stmt->src->element_type().is_pointer()) {
ptr_ty = llvm_type(stmt->src->element_type().ptr_removed());
ptr_ty = tlctx->get_data_type(stmt->src->element_type().ptr_removed());
}
TI_ASSERT(ptr_ty);
llvm_val[stmt] = builder->CreateLoad(ptr_ty, llvm_val[stmt->src]);
Expand Down Expand Up @@ -1298,12 +1262,14 @@ llvm::Value *TaskCodeGenLLVM::quant_type_atomic(AtomicOpStmt *stmt) {
if (auto qit = dst_type->cast<QuantIntType>()) {
return atomic_add_quant_int(
llvm_val[stmt->dest],
llvm_type(stmt->dest->as<GetChStmt>()->input_snode->physical_type), qit,
llvm_val[stmt->val], is_signed(stmt->val->ret_type));
tlctx->get_data_type(
stmt->dest->as<GetChStmt>()->input_snode->physical_type),
qit, llvm_val[stmt->val], is_signed(stmt->val->ret_type));
} else if (auto qfxt = dst_type->cast<QuantFixedType>()) {
return atomic_add_quant_fixed(
llvm_val[stmt->dest],
llvm_type(stmt->dest->as<GetChStmt>()->input_snode->physical_type),
tlctx->get_data_type(
stmt->dest->as<GetChStmt>()->input_snode->physical_type),
qfxt, llvm_val[stmt->val]);
} else {
return nullptr;
Expand Down Expand Up @@ -1464,11 +1430,13 @@ void TaskCodeGenLLVM::visit(GlobalStoreStmt *stmt) {
pointee_type->to_string());
}
if (auto qit = pointee_type->cast<QuantIntType>()) {
store_quant_int(llvm_val[stmt->dest], llvm_type(snode->physical_type),
qit, llvm_val[stmt->val], true);
store_quant_int(llvm_val[stmt->dest],
tlctx->get_data_type(snode->physical_type), qit,
llvm_val[stmt->val], true);
} else if (auto qfxt = pointee_type->cast<QuantFixedType>()) {
store_quant_fixed(llvm_val[stmt->dest], llvm_type(snode->physical_type),
qfxt, llvm_val[stmt->val], true);
store_quant_fixed(llvm_val[stmt->dest],
tlctx->get_data_type(snode->physical_type), qfxt,
llvm_val[stmt->val], true);
} else {
TI_NOT_IMPLEMENTED;
}
Expand All @@ -1489,7 +1457,8 @@ void TaskCodeGenLLVM::create_global_load(GlobalLoadStmt *stmt,
if (ptr_type->is_bit_pointer()) {
auto val_type = ptr_type->get_pointee_type();
auto get_ch = stmt->src->as<GetChStmt>();
auto physical_type = llvm_type(get_ch->input_snode->physical_type);
auto physical_type =
tlctx->get_data_type(get_ch->input_snode->physical_type);
auto [byte_ptr, bit_offset] = load_bit_ptr(ptr);
auto physical_value = should_cache_as_read_only
? create_intrinsic_load(byte_ptr, physical_type)
Expand All @@ -1510,7 +1479,8 @@ void TaskCodeGenLLVM::create_global_load(GlobalLoadStmt *stmt,
} else {
// Byte pointer case.
if (should_cache_as_read_only) {
llvm_val[stmt] = create_intrinsic_load(ptr, llvm_type(stmt->ret_type));
llvm_val[stmt] =
create_intrinsic_load(ptr, tlctx->get_data_type(stmt->ret_type));
} else {
llvm_val[stmt] =
builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), ptr);
Expand Down Expand Up @@ -1892,7 +1862,7 @@ std::tuple<llvm::Value *, llvm::Value *> TaskCodeGenLLVM::get_range_for_bounds(
begin_stmt->accept(this);
begin = builder->CreateLoad(
#ifdef TI_LLVM_15
llvm_type(PrimitiveType::i32),
tlctx->get_data_type(PrimitiveType::i32),
#endif
llvm_val[begin_stmt.get()]);
}
Expand All @@ -1904,7 +1874,7 @@ std::tuple<llvm::Value *, llvm::Value *> TaskCodeGenLLVM::get_range_for_bounds(
end_stmt->accept(this);
end = builder->CreateLoad(
#ifdef TI_LLVM_15
llvm_type(PrimitiveType::i32),
tlctx->get_data_type(PrimitiveType::i32),
#endif
llvm_val[end_stmt.get()]);
}
Expand Down
4 changes: 0 additions & 4 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,6 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

static std::string get_runtime_snode_name(SNode *snode);

llvm::Type *llvm_type(DataType dt);

llvm::Type *llvm_ptr_type(DataType dt);

void visit(Block *stmt_list) override;

void visit(AllocaStmt *stmt) override;
Expand Down
24 changes: 14 additions & 10 deletions taichi/codegen/llvm/codegen_llvm_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ llvm::Value *TaskCodeGenLLVM::to_quant_fixed(llvm::Value *real,
// Compute int(real * (1.0 / scale) + 0.5)
auto compute_type = qfxt->get_compute_type();
auto s = builder->CreateFPCast(tlctx->get_constant(1.0 / qfxt->get_scale()),
llvm_type(compute_type));
auto input_real = builder->CreateFPCast(real, llvm_type(compute_type));
tlctx->get_data_type(compute_type));
auto input_real =
builder->CreateFPCast(real, tlctx->get_data_type(compute_type));
auto scaled = builder->CreateFMul(input_real, s);

// Add/minus the 0.5 offset for rounding
Expand All @@ -58,9 +59,11 @@ llvm::Value *TaskCodeGenLLVM::to_quant_fixed(llvm::Value *real,

auto qit = qfxt->get_digits_type()->as<QuantIntType>();
if (qit->get_is_signed()) {
return builder->CreateFPToSI(scaled, llvm_type(qit->get_compute_type()));
return builder->CreateFPToSI(scaled,
tlctx->get_data_type(qit->get_compute_type()));
} else {
return builder->CreateFPToUI(scaled, llvm_type(qit->get_compute_type()));
return builder->CreateFPToUI(scaled,
tlctx->get_data_type(qit->get_compute_type()));
}
}

Expand Down Expand Up @@ -143,7 +146,7 @@ llvm::Value *TaskCodeGenLLVM::quant_int_or_quant_fixed_to_bits(

void TaskCodeGenLLVM::visit(BitStructStoreStmt *stmt) {
auto bit_struct = stmt->get_bit_struct();
auto physical_type = llvm_type(bit_struct->get_physical_type());
auto physical_type = tlctx->get_data_type(bit_struct->get_physical_type());

int num_non_exponent_children = 0;
for (int i = 0; i < bit_struct->get_num_members(); i++) {
Expand Down Expand Up @@ -288,7 +291,7 @@ void TaskCodeGenLLVM::store_quant_floats_with_shared_exponents(
BitStructStoreStmt *stmt) {
// handle each exponent separately
auto bit_struct = stmt->get_bit_struct();
auto physical_type = llvm_type(bit_struct->get_physical_type());
auto physical_type = tlctx->get_data_type(bit_struct->get_physical_type());
auto physical_value = builder->CreateLoad(physical_type, llvm_val[stmt->ptr]);
// fuse all stores into a masked store
llvm::Value *masked_val = nullptr;
Expand Down Expand Up @@ -469,7 +472,8 @@ llvm::Value *TaskCodeGenLLVM::extract_quant_int(llvm::Value *physical_value,
else
step2 = builder->CreateLShr(step1, right);

return builder->CreateIntCast(step2, llvm_type(qit->get_compute_type()),
return builder->CreateIntCast(step2,
tlctx->get_data_type(qit->get_compute_type()),
qit->get_is_signed());
}

Expand All @@ -479,12 +483,12 @@ llvm::Value *TaskCodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits,
llvm::Value *cast = nullptr;
auto compute_type = qfxt->get_compute_type()->as<PrimitiveType>();
if (qfxt->get_is_signed()) {
cast = builder->CreateSIToFP(digits, llvm_type(compute_type));
cast = builder->CreateSIToFP(digits, tlctx->get_data_type(compute_type));
} else {
cast = builder->CreateUIToFP(digits, llvm_type(compute_type));
cast = builder->CreateUIToFP(digits, tlctx->get_data_type(compute_type));
}
llvm::Value *s = tlctx->get_constant(qfxt->get_scale());
s = builder->CreateFPCast(s, llvm_type(compute_type));
s = builder->CreateFPCast(s, tlctx->get_data_type(compute_type));
return builder->CreateFMul(cast, s);
}

Expand Down
32 changes: 15 additions & 17 deletions taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,38 +120,36 @@ TaichiLLVMContext::~TaichiLLVMContext() {

llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) {
auto ctx = get_this_thread_context();
if (dt->is_primitive(PrimitiveTypeID::i32)) {
return llvm::Type::getInt32Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::i8)) {
if (dt->is_primitive(PrimitiveTypeID::i8) ||
dt->is_primitive(PrimitiveTypeID::u8)) {
return llvm::Type::getInt8Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
} else if (dt->is_primitive(PrimitiveTypeID::i16) ||
dt->is_primitive(PrimitiveTypeID::u16)) {
return llvm::Type::getInt16Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::i64)) {
} else if (dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::u32)) {
return llvm::Type::getInt32Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::i64) ||
dt->is_primitive(PrimitiveTypeID::u64)) {
return llvm::Type::getInt64Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return llvm::Type::getInt1Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::f32)) {
return llvm::Type::getFloatTy(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
return llvm::Type::getDoubleTy(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return llvm::Type::getInt8Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
return llvm::Type::getInt16Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::u32)) {
return llvm::Type::getInt32Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::u64)) {
return llvm::Type::getInt64Ty(*ctx);
} else if (dt->is_primitive(PrimitiveTypeID::f16)) {
return llvm::Type::getHalfTy(*ctx);
} else if (dt->is<TensorType>()) {
TI_ASSERT_INFO(config_->real_matrix,
"Real matrix not enabled but got TensorType");
auto vectorty = dt->as<TensorType>();
auto dtype = this->get_data_type(vectorty->get_element_type());
return llvm::VectorType::get(dtype, vectorty->get_num_elements(),
auto tensor_type = dt->cast<TensorType>();
auto element_type = get_data_type(tensor_type->get_element_type());
return llvm::VectorType::get(element_type, tensor_type->get_num_elements(),
/*scalable=*/false);
} else {
TI_INFO(data_type_name(dt));
TI_NOT_IMPLEMENTED
TI_NOT_IMPLEMENTED;
}
}

Expand Down