Skip to content

Commit

Permalink
fix cast to f16
Browse files Browse the repository at this point in the history
  • Loading branch information
AD1024 committed Aug 25, 2022
1 parent 76d71e0 commit b6efff2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
15 changes: 8 additions & 7 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,11 @@ void TaskCodeGenLLVM::create_fp_trunc(
UnaryOpStmt *stmt,
std::function<llvm::Value *(llvm::Value *, llvm::Type *)> trunc_fn,
llvm::Type *to_ty,
bool is_tensor) {
bool is_tensor,
bool trunc_self) {
if (!is_tensor) {
llvm_val[stmt] = trunc_fn(llvm_val[stmt->operand], to_ty);
llvm_val[stmt] =
trunc_fn(trunc_self ? llvm_val[stmt] : llvm_val[stmt->operand], to_ty);
} else {
auto from_ty = stmt->operand->ret_type->cast<TensorType>();
TI_ASSERT_INFO(from_ty,
Expand All @@ -384,7 +386,8 @@ void TaskCodeGenLLVM::create_fp_trunc(
// This assumes cast does not change the number of
// elements in a tensor value (should be legit)
for (int i = 0; i < from_ty->get_num_elements(); ++i) {
auto elem = builder->CreateExtractElement(llvm_val[stmt->operand], i);
auto elem = builder->CreateExtractElement(
trunc_self ? llvm_val[stmt] : llvm_val[stmt->operand], i);
auto trunc_value = trunc_fn(elem, to_ty);
vec = builder->CreateInsertElement(vec, trunc_value, i);
}
Expand All @@ -403,8 +406,7 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
builder->CreateIntrinsic(llvm::Intrinsic::x, {input_type}, {input}); \
}
if (stmt->op_type == UnaryOpType::cast_value) {
// Suppress warning
llvm::CastInst::CastOps cast_op __attribute__((unused));
llvm::CastInst::CastOps cast_op;
auto from = stmt->operand->ret_type;
auto to = stmt->cast_type;
TI_ASSERT_INFO(
Expand All @@ -415,7 +417,6 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
llvm_val[stmt] = llvm_val[stmt->operand];
} else if (is_real(from) != is_real(to) ||
is_real_tensor(from) != is_real_tensor(to)) {
llvm::CastInst::CastOps cast_op;
if ((is_real(from) || is_real_tensor(from)) &&
(is_integral(to) || is_integral_tensor(to))) {
cast_op = (is_signed(to) || is_signed_tensor(to))
Expand Down Expand Up @@ -452,7 +453,7 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
return this->builder->CreateFPTrunc(value, type);
};
create_fp_trunc(stmt, trunc_func, llvm::Type::getHalfTy(*llvm_context),
cast_type->is<TensorType>());
cast_type->is<TensorType>(), /*trunc_self=*/true);
}
} else if ((is_real(from) || is_real_tensor(from)) &&
(is_real(to) || is_real_tensor(to))) {
Expand Down
3 changes: 2 additions & 1 deletion taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
UnaryOpStmt *stmt,
std::function<llvm::Value *(llvm::Value *, llvm::Type *)> trunc_fn,
llvm::Type *to_ty,
bool is_tensor);
bool is_tensor,
bool trunc_self = false);

std::unique_ptr<RuntimeObject> emit_struct_meta_object(SNode *snode);

Expand Down

0 comments on commit b6efff2

Please sign in to comment.