Skip to content

Commit

Permalink
impl cast for tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
AD1024 committed Aug 25, 2022
1 parent 682e8de commit 76d71e0
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 55 deletions.
194 changes: 139 additions & 55 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,82 +344,166 @@ TaskCodeGenLLVM::TaskCodeGenLLVM(Kernel *kernel,
void TaskCodeGenLLVM::visit(DecorationStmt *stmt) {
}

void TaskCodeGenLLVM::create_value_cast(
UnaryOpStmt *stmt,
std::function<llvm::Value *(llvm::Value *, llvm::Type *)> cast_fn,
DataType to_ty) {
if (!to_ty->is<TensorType>()) {
llvm_val[stmt] =
cast_fn(llvm_val[stmt->operand], tlctx->get_data_type(to_ty));
} else {
auto from_ty = stmt->operand->ret_type->cast<TensorType>();
TI_ASSERT_INFO(from_ty, "Cannot cast non-tensor type {} to {}",
from_ty->to_string(), to_ty->to_string());
auto tensor_type = to_ty->cast<TensorType>();
llvm::Value *vec = llvm::UndefValue::get(tlctx->get_data_type(tensor_type));
for (int i = 0; i < tensor_type->get_num_elements(); ++i) {
auto elem = builder->CreateExtractElement(llvm_val[stmt->operand], i);
auto cast_input =
cast_fn(elem, tlctx->get_data_type(tensor_type->get_element_type()));
vec = builder->CreateInsertElement(vec, cast_input, i);
}
llvm_val[stmt] = vec;
}
}

void TaskCodeGenLLVM::create_fp_trunc(
UnaryOpStmt *stmt,
std::function<llvm::Value *(llvm::Value *, llvm::Type *)> trunc_fn,
llvm::Type *to_ty,
bool is_tensor) {
if (!is_tensor) {
llvm_val[stmt] = trunc_fn(llvm_val[stmt->operand], to_ty);
} else {
auto from_ty = stmt->operand->ret_type->cast<TensorType>();
TI_ASSERT_INFO(from_ty,
"Cannot truncate non-tensor type {} to a tensor type",
from_ty->to_string());
llvm::Value *vec = llvm::UndefValue::get(llvm::VectorType::get(
to_ty, from_ty->get_num_elements(), /*scalable=*/false));
// This assumes cast does not change the number of
// elements in a tensor value (should be legit)
for (int i = 0; i < from_ty->get_num_elements(); ++i) {
auto elem = builder->CreateExtractElement(llvm_val[stmt->operand], i);
auto trunc_value = trunc_fn(elem, to_ty);
vec = builder->CreateInsertElement(vec, trunc_value, i);
}
llvm_val[stmt] = vec;
}
}

void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
auto input = llvm_val[stmt->operand];
auto input_type = input->getType();
auto op = stmt->op_type;
auto get_cast_op = [](auto from, auto to) {
llvm::CastInst::CastOps cast_op;
if (is_real(from) && is_integral(to)) {
cast_op = is_signed(to) ? llvm::Instruction::CastOps::FPToSI
: llvm::Instruction::CastOps::FPToUI;
} else if (is_integral(from) && is_real(to)) {
cast_op = is_signed(from) ? llvm::Instruction::CastOps::SIToFP
: llvm::Instruction::CastOps::UIToFP;
} else {
TI_P(data_type_name(from));
TI_P(data_type_name(to));
TI_NOT_IMPLEMENTED;
}
return cast_op;
};

#define UNARY_INTRINSIC(x) \
else if (op == UnaryOpType::x) { \
llvm_val[stmt] = \
builder->CreateIntrinsic(llvm::Intrinsic::x, {input_type}, {input}); \
}
if (stmt->op_type == UnaryOpType::cast_value) {
llvm::CastInst::CastOps cast_op;
// Suppress warning
llvm::CastInst::CastOps cast_op __attribute__((unused));
auto from = stmt->operand->ret_type;
auto to = stmt->cast_type;
TI_ASSERT_INFO(
from->is<TensorType>() == to->is<TensorType>(),
"Cannot cast between tensor type and non-tensor type: {} v.s. {}",
from->to_string(), to->to_string());
if (from == to) {
llvm_val[stmt] = llvm_val[stmt->operand];
} else if (from->is<TensorType>()) {
TI_ASSERT_INFO(to->is<TensorType>(),
"Only tensor to tensor cast is supported, {} provided",
to->to_string());
auto from_ty = from->cast<TensorType>()->get_element_type();
auto to_ty = to->cast<TensorType>()->get_element_type();
cast_op = get_cast_op(from_ty, to_ty);
auto type = tlctx->get_data_type(to->cast<TensorType>());
llvm::Value *vec = llvm::UndefValue::get(type);
for (int i = 0; i < from->cast<TensorType>()->get_num_elements(); ++i) {
auto elem = builder->CreateExtractElement(llvm_val[stmt->operand], i);
auto cast_input =
builder->CreateCast(cast_op, elem, tlctx->get_data_type(to_ty));
vec = builder->CreateInsertElement(vec, cast_input, i);
} else if (is_real(from) != is_real(to) ||
is_real_tensor(from) != is_real_tensor(to)) {
llvm::CastInst::CastOps cast_op;
if ((is_real(from) || is_real_tensor(from)) &&
(is_integral(to) || is_integral_tensor(to))) {
cast_op = (is_signed(to) || is_signed_tensor(to))
? llvm::Instruction::CastOps::FPToSI
: llvm::Instruction::CastOps::FPToUI;
} else if ((is_integral(from) || is_integral_tensor(from)) &&
(is_real(to) || is_real_tensor(to))) {
cast_op = (is_signed(from) || is_signed_tensor(from))
? llvm::Instruction::CastOps::SIToFP
: llvm::Instruction::CastOps::UIToFP;
} else {
TI_P(data_type_name(from));
TI_P(data_type_name(to));
TI_NOT_IMPLEMENTED;
}
llvm_val[stmt] = vec;
} else if (is_real(from) != is_real(to)) {
cast_op = get_cast_op(from, to);
auto cast_type = to->is_primitive(PrimitiveTypeID::f16)
? PrimitiveType::f32
: stmt->cast_type;

llvm_val[stmt] = builder->CreateCast(cast_op, llvm_val[stmt->operand],
tlctx->get_data_type(cast_type));

if (to->is_primitive(PrimitiveTypeID::f16)) {
llvm_val[stmt] = builder->CreateFPTrunc(
llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context));
bool use_f16 = to->is_primitive(PrimitiveTypeID::f16) ||
(to->is<TensorType>() &&
to->cast<TensorType>()->get_element_type()->is_primitive(
PrimitiveTypeID::f16));
auto cast_type = use_f16 ? (to->is<TensorType>()
? TypeFactory::create_tensor_type(
to->cast<TensorType>()->get_shape(),
PrimitiveType::f32)
: PrimitiveType::f32)
: stmt->cast_type;

auto cast_func = [this, cast_op](llvm::Value *value, llvm::Type *type) {
return this->builder->CreateCast(cast_op, value, type);
};
create_value_cast(stmt, cast_func, cast_type);

if (use_f16) {
auto trunc_func = [this](llvm::Value *value, llvm::Type *type) {
return this->builder->CreateFPTrunc(value, type);
};
create_fp_trunc(stmt, trunc_func, llvm::Type::getHalfTy(*llvm_context),
cast_type->is<TensorType>());
}
} else if (is_real(from) && is_real(to)) {
if (data_type_size(from) < data_type_size(to)) {
llvm_val[stmt] = builder->CreateFPExt(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
} else if ((is_real(from) || is_real_tensor(from)) &&
(is_real(to) || is_real_tensor(to))) {
auto t1 = from->is<TensorType>()
? from->cast<TensorType>()->get_element_type()
: from.operator->();
auto t2 = to->is<TensorType>()
? to->cast<TensorType>()->get_element_type()
: to.operator->();
if (data_type_size(t1) < data_type_size(t2)) {
auto cast_func = [this](llvm::Value *value, llvm::Type *type) {
return this->builder->CreateFPExt(value, type);
};
create_value_cast(stmt, cast_func, stmt->cast_type);
} else {
if (to->is_primitive(PrimitiveTypeID::f16)) {
llvm_val[stmt] = builder->CreateFPTrunc(
builder->CreateFPTrunc(llvm_val[stmt->operand],
llvm::Type::getFloatTy(*llvm_context)),
llvm::Type::getHalfTy(*llvm_context));
if (to->is_primitive(PrimitiveTypeID::f16) ||
(to->is<TensorType>() &&
to->cast<TensorType>()->get_element_type()->is_primitive(
PrimitiveTypeID::f16))) {
if (!to->is<TensorType>()) {
llvm_val[stmt] = builder->CreateFPTrunc(
builder->CreateFPTrunc(llvm_val[stmt->operand],
llvm::Type::getFloatTy(*llvm_context)),
llvm::Type::getHalfTy(*llvm_context));
} else {
auto tensor_type = to->cast<TensorType>();
llvm::Value *vec = llvm::UndefValue::get(tlctx->get_data_type(to));
for (int i = 0; i < tensor_type->get_num_elements(); ++i) {
auto elem = builder->CreateExtractElement(vec, i);
auto double_trunced = builder->CreateFPTrunc(
builder->CreateFPTrunc(elem,
llvm::Type::getFloatTy(*llvm_context)),
llvm::Type::getHalfTy(*llvm_context));
vec = builder->CreateInsertElement(vec, double_trunced, i);
}
llvm_val[stmt] = vec;
}
} else {
llvm_val[stmt] = builder->CreateFPTrunc(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
auto trunc_fn = [this](llvm::Value *value, llvm::Type *type) {
return this->builder->CreateFPTrunc(value, type);
};
auto cast_type =
stmt->cast_type->is<TensorType>()
? stmt->cast_type->cast<TensorType>()->get_element_type()
: stmt->cast_type.operator->();
create_fp_trunc(stmt, trunc_fn, tlctx->get_data_type(cast_type),
stmt->cast_type->is<TensorType>());
}
}
} else if (!is_real(from) && !is_real(to)) {
} else if (!(is_real(from) || is_real_tensor(from)) &&
!(is_real(to) || is_real_tensor(to))) {
llvm_val[stmt] = builder->CreateIntCast(llvm_val[stmt->operand],
llvm_type(to), is_signed(from));
}
Expand Down
11 changes: 11 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *node_meta,
SNode *snode);

void create_value_cast(
UnaryOpStmt *stmt,
std::function<llvm::Value *(llvm::Value *, llvm::Type *)> cast_fn,
DataType to_ty);

void create_fp_trunc(
UnaryOpStmt *stmt,
std::function<llvm::Value *(llvm::Value *, llvm::Type *)> trunc_fn,
llvm::Type *to_ty,
bool is_tensor);

std::unique_ptr<RuntimeObject> emit_struct_meta_object(SNode *snode);

llvm::Value *emit_struct_meta(SNode *snode);
Expand Down

0 comments on commit 76d71e0

Please sign in to comment.