Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Use PrimitiveType::type instead of DataType::type #1926

Merged
merged 2 commits into from
Oct 6, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions taichi/analysis/value_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class ValueDiffLoopIndex : public IRVisitor {
}

void visit(ConstStmt *stmt) override {
if (stmt->val[lane].dt == DataType::i32) {
if (stmt->val[lane].dt == PrimitiveType::i32) {
results[stmt->instance_id] = DiffRange(true, 0, stmt->val[lane].val_i32);
} else {
results[stmt->instance_id] = DiffRange();
Expand Down Expand Up @@ -112,7 +112,7 @@ class FindDirectValueBaseAndOffset : public IRVisitor {

void visit(ConstStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
if (stmt->val[0].dt == DataType::i32) {
if (stmt->val[0].dt == PrimitiveType::i32) {
result = std::make_tuple(true, nullptr, stmt->val[0].val_i32);
}
}
Expand Down
8 changes: 4 additions & 4 deletions taichi/backends/cc/codegen_cc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,13 @@ class CCTransformer : public IRVisitor {
}

static std::string _get_libc_function_name(std::string name, DataType dt) {
if (dt == DataType::i32)
if (dt == PrimitiveType::i32)
return name;
else if (dt == DataType::i64)
else if (dt == PrimitiveType::i64)
return "ll" + name;
else if (dt == DataType::f32)
else if (dt == PrimitiveType::f32)
return name + "f";
else if (dt == DataType::f64)
else if (dt == PrimitiveType::f64)
return name;
else
TI_ERROR("Unsupported function \"{}\" for DataType={} on C backend", name,
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/cpu/codegen_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class CodeGenLLVMCPU : public CodeGenLLVM {
llvm::Type::getInt8PtrTy(*llvm_context),
tlctx->get_data_type<int>()});

auto loop_var = create_entry_block_alloca(DataType::i32);
auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
loop_vars_llvm[stmt].push_back(loop_var);
builder->CreateStore(get_arg(2), loop_var);
stmt->body->accept(this);
Expand Down
36 changes: 18 additions & 18 deletions taichi/backends/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {

auto value_type = tlctx->get_data_type(arg_stmt->ret_type.data_type);
auto value = llvm_val[arg_stmt];
if (arg_stmt->ret_type.data_type == DataType::f32) {
value_type = tlctx->get_data_type(DataType::f64);
if (arg_stmt->ret_type.data_type == PrimitiveType::f32) {
value_type = tlctx->get_data_type(PrimitiveType::f64);
value = builder->CreateFPExt(value, value_type);
}

Expand Down Expand Up @@ -162,43 +162,43 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {

#define UNARY_STD(x) \
else if (op == UnaryOpType::x) { \
if (input_taichi_type == DataType::f32) { \
if (input_taichi_type == PrimitiveType::f32) { \
llvm_val[stmt] = \
builder->CreateCall(get_runtime_function("__nv_" #x "f"), input); \
} else if (input_taichi_type == DataType::f64) { \
} else if (input_taichi_type == PrimitiveType::f64) { \
llvm_val[stmt] = \
builder->CreateCall(get_runtime_function("__nv_" #x), input); \
} else if (input_taichi_type == DataType::i32) { \
} else if (input_taichi_type == PrimitiveType::i32) { \
llvm_val[stmt] = builder->CreateCall(get_runtime_function(#x), input); \
} else { \
TI_NOT_IMPLEMENTED \
} \
}
if (op == UnaryOpType::abs) {
if (input_taichi_type == DataType::f32) {
if (input_taichi_type == PrimitiveType::f32) {
llvm_val[stmt] =
builder->CreateCall(get_runtime_function("__nv_fabsf"), input);
} else if (input_taichi_type == DataType::f64) {
} else if (input_taichi_type == PrimitiveType::f64) {
llvm_val[stmt] =
builder->CreateCall(get_runtime_function("__nv_fabs"), input);
} else if (input_taichi_type == DataType::i32) {
} else if (input_taichi_type == PrimitiveType::i32) {
llvm_val[stmt] =
builder->CreateCall(get_runtime_function("__nv_abs"), input);
} else {
TI_NOT_IMPLEMENTED
}
} else if (op == UnaryOpType::sqrt) {
if (input_taichi_type == DataType::f32) {
if (input_taichi_type == PrimitiveType::f32) {
llvm_val[stmt] =
builder->CreateCall(get_runtime_function("__nv_sqrtf"), input);
} else if (input_taichi_type == DataType::f64) {
} else if (input_taichi_type == PrimitiveType::f64) {
llvm_val[stmt] =
builder->CreateCall(get_runtime_function("__nv_sqrt"), input);
} else {
TI_NOT_IMPLEMENTED
}
} else if (op == UnaryOpType::logic_not) {
if (input_taichi_type == DataType::i32) {
if (input_taichi_type == PrimitiveType::i32) {
llvm_val[stmt] =
builder->CreateCall(get_runtime_function("logic_not_i32"), input);
} else {
Expand Down Expand Up @@ -236,11 +236,11 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
llvm::AtomicRMWInst::BinOp::Add, llvm_val[stmt->dest],
llvm_val[stmt->val],
llvm::AtomicOrdering::SequentiallyConsistent);
} else if (stmt->val->ret_type.data_type == DataType::f32) {
} else if (stmt->val->ret_type.data_type == PrimitiveType::f32) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::FAdd, llvm_val[stmt->dest],
llvm_val[stmt->val], AtomicOrdering::SequentiallyConsistent);
} else if (stmt->val->ret_type.data_type == DataType::f64) {
} else if (stmt->val->ret_type.data_type == PrimitiveType::f64) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::FAdd, llvm_val[stmt->dest],
llvm_val[stmt->val], AtomicOrdering::SequentiallyConsistent);
Expand All @@ -253,11 +253,11 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
llvm::AtomicRMWInst::BinOp::Min, llvm_val[stmt->dest],
llvm_val[stmt->val],
llvm::AtomicOrdering::SequentiallyConsistent);
} else if (stmt->val->ret_type.data_type == DataType::f32) {
} else if (stmt->val->ret_type.data_type == PrimitiveType::f32) {
old_value =
builder->CreateCall(get_runtime_function("atomic_min_f32"),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
} else if (stmt->val->ret_type.data_type == DataType::f64) {
} else if (stmt->val->ret_type.data_type == PrimitiveType::f64) {
old_value =
builder->CreateCall(get_runtime_function("atomic_min_f64"),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
Expand All @@ -270,11 +270,11 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
llvm::AtomicRMWInst::BinOp::Max, llvm_val[stmt->dest],
llvm_val[stmt->val],
llvm::AtomicOrdering::SequentiallyConsistent);
} else if (stmt->val->ret_type.data_type == DataType::f32) {
} else if (stmt->val->ret_type.data_type == PrimitiveType::f32) {
old_value =
builder->CreateCall(get_runtime_function("atomic_max_f32"),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
} else if (stmt->val->ret_type.data_type == DataType::f64) {
} else if (stmt->val->ret_type.data_type == PrimitiveType::f64) {
old_value =
builder->CreateCall(get_runtime_function("atomic_max_f64"),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
Expand Down Expand Up @@ -334,7 +334,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
{llvm::PointerType::get(get_runtime_type("Context"), 0),
get_tls_buffer_type(), tlctx->get_data_type<int>()});

auto loop_var = create_entry_block_alloca(DataType::i32);
auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
loop_vars_llvm[stmt].push_back(loop_var);
builder->CreateStore(get_arg(2), loop_var);
stmt->body->accept(this);
Expand Down
12 changes: 6 additions & 6 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ class KernelCodegen : public IRVisitor {
}
} else if (opty == SNodeOpType::append) {
TI_ASSERT(is_dynamic);
TI_ASSERT(stmt->ret_type.data_type == DataType::i32);
TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32);
emit("{} = {}.append({});", result_var, parent, stmt->val->raw_name());
} else if (opty == SNodeOpType::length) {
TI_ASSERT(is_dynamic);
Expand Down Expand Up @@ -485,19 +485,19 @@ class KernelCodegen : public IRVisitor {
current_appender().push_indent();
}

if (dt == DataType::i32) {
if (dt == PrimitiveType::i32) {
emit(
"const auto {} = atomic_fetch_{}_explicit((device atomic_int*){}, "
"{}, "
"metal::memory_order_relaxed);",
stmt->raw_name(), op_name, stmt->dest->raw_name(), val_var);
} else if (dt == DataType::u32) {
} else if (dt == PrimitiveType::u32) {
emit(
"const auto {} = atomic_fetch_{}_explicit((device atomic_uint*){}, "
"{}, "
"metal::memory_order_relaxed);",
stmt->raw_name(), op_name, stmt->dest->raw_name(), val_var);
} else if (dt == DataType::f32) {
} else if (dt == PrimitiveType::f32) {
if (handle_float) {
emit("const float {} = fatomic_fetch_{}({}, {});", stmt->raw_name(),
op_name, stmt->dest->raw_name(), val_var);
Expand Down Expand Up @@ -624,7 +624,7 @@ class KernelCodegen : public IRVisitor {
if (std::holds_alternative<Stmt *>(entry)) {
auto *arg_stmt = std::get<Stmt *>(entry);
const auto dt = arg_stmt->element_type();
TI_ASSERT_INFO(dt == DataType::i32 || dt == DataType::f32,
TI_ASSERT_INFO(dt == PrimitiveType::i32 || dt == PrimitiveType::f32,
"print() only supports i32 or f32 scalars for now.");
emit("{}.pm_set_{}({}, {});", msg_var_name, data_type_short_name(dt),
i, arg_stmt->raw_name());
Expand Down Expand Up @@ -1037,7 +1037,7 @@ class KernelCodegen : public IRVisitor {
used_features()->sparse = true;
}

std::string inject_load_global_tmp(int offset, DataType dt = DataType::i32) {
std::string inject_load_global_tmp(int offset, DataType dt = PrimitiveType::i32) {
const auto vt = VectorType(/*width=*/1, dt);
auto gtmp = Stmt::make<GlobalTemporaryStmt>(offset, vt);
gtmp->accept(this);
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/metal/data_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ TLANG_NAMESPACE_BEGIN
namespace metal {

MetalDataType to_metal_type(DataType dt) {
#define METAL_CASE(x) else if (dt == DataType::x) return MetalDataType::x
#define METAL_CASE(x) else if (dt == PrimitiveType::x) return MetalDataType::x
if (false) {
}
METAL_CASE(f32);
Expand Down
36 changes: 18 additions & 18 deletions taichi/backends/opengl/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,22 @@ class KernelGen : public IRVisitor {
// Note that the following two functions not only returns the corresponding
// data type, but also **records** the usage of `i64` and `f64`.
std::string opengl_data_type_short_name(DataType dt) {
if (dt == DataType::i64) {
if (dt == PrimitiveType::i64) {
if (!TI_OPENGL_REQUIRE(used, GL_ARB_gpu_shader_int64)) {
TI_ERROR(
"Extension GL_ARB_gpu_shader_int64 not supported on your OpenGL");
}
used.int64 = true;
}
if (dt == DataType::f64)
if (dt == PrimitiveType::f64)
used.float64 = true;
return data_type_short_name(dt);
}

std::string opengl_data_type_name(DataType dt) {
if (dt == DataType::i64)
if (dt == PrimitiveType::i64)
used.int64 = true;
if (dt == DataType::f64)
if (dt == PrimitiveType::f64)
used.float64 = true;
return opengl::opengl_data_type_name(dt);
}
Expand Down Expand Up @@ -360,7 +360,7 @@ class KernelGen : public IRVisitor {
}

} else if (stmt->op_type == SNodeOpType::is_active) {
TI_ASSERT(stmt->ret_type.data_type == DataType::i32);
TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32);
if (stmt->snode->type == SNodeType::dense ||
stmt->snode->type == SNodeType::root) {
emit("int {} = 1;", stmt->short_name());
Expand All @@ -373,7 +373,7 @@ class KernelGen : public IRVisitor {

} else if (stmt->op_type == SNodeOpType::append) {
TI_ASSERT(stmt->snode->type == SNodeType::dynamic);
TI_ASSERT(stmt->ret_type.data_type == DataType::i32);
TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32);
emit("int {} = atomicAdd(_data_i32_[{} >> 2], 1);", stmt->short_name(),
get_snode_meta_address(stmt->snode));
auto dt = stmt->val->element_type();
Expand All @@ -387,7 +387,7 @@ class KernelGen : public IRVisitor {

} else if (stmt->op_type == SNodeOpType::length) {
TI_ASSERT(stmt->snode->type == SNodeType::dynamic);
TI_ASSERT(stmt->ret_type.data_type == DataType::i32);
TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32);
emit("int {} = _data_i32_[{} >> 2];", stmt->short_name(),
get_snode_meta_address(stmt->snode));

Expand Down Expand Up @@ -479,12 +479,12 @@ class KernelGen : public IRVisitor {
emit("{} {} = {}({});", dt_name, stmt->short_name(),
opengl_data_type_name(stmt->cast_type), stmt->operand->short_name());
} else if (stmt->op_type == UnaryOpType::cast_bits) {
if (stmt->cast_type == DataType::f32 &&
stmt->operand->element_type() == DataType::i32) {
if (stmt->cast_type == PrimitiveType::f32 &&
stmt->operand->element_type() == PrimitiveType::i32) {
emit("{} {} = intBitsToFloat({});", dt_name, stmt->short_name(),
stmt->operand->short_name());
} else if (stmt->cast_type == DataType::i32 &&
stmt->operand->element_type() == DataType::f32) {
} else if (stmt->cast_type == PrimitiveType::i32 &&
stmt->operand->element_type() == PrimitiveType::f32) {
emit("{} {} = floatBitsToInt({});", dt_name, stmt->short_name(),
stmt->operand->short_name());
} else {
Expand Down Expand Up @@ -527,7 +527,7 @@ class KernelGen : public IRVisitor {
return;
} else if (bin->op_type == BinaryOpType::atan2) {
if (bin->element_type() ==
DataType::f64) { // don't know why no atan(double, double)
PrimitiveType::f64) { // don't know why no atan(double, double)
emit("{} {} = {}(atan(float({}), float({})));", dt_name, bin_name,
dt_name, lhs_name, rhs_name);
} else {
Expand Down Expand Up @@ -573,25 +573,25 @@ class KernelGen : public IRVisitor {
void visit(AtomicOpStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
auto dt = stmt->dest->element_type();
if (dt == DataType::i32 ||
if (dt == PrimitiveType::i32 ||
(TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_int64) &&
dt == DataType::i64) ||
dt == PrimitiveType::i64) ||
((stmt->op_type == AtomicOpType::add ||
stmt->op_type == AtomicOpType::sub) &&
((TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_float) &&
dt == DataType::f32) ||
dt == PrimitiveType::f32) ||
(TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_float64) &&
dt == DataType::f64)))) {
dt == PrimitiveType::f64)))) {
emit("{} {} = {}(_{}_{}_[{} >> {}], {});",
opengl_data_type_name(stmt->val->element_type()), stmt->short_name(),
opengl_atomic_op_type_cap_name(stmt->op_type),
ptr_signats.at(stmt->dest->id), opengl_data_type_short_name(dt),
stmt->dest->short_name(), opengl_data_address_shifter(dt),
stmt->val->short_name());
} else {
if (dt != DataType::f32) {
if (dt != PrimitiveType::f32) {
TI_ERROR(
"unsupported atomic operation for DataType::{}, "
"unsupported atomic operation for PrimitiveType::{}, "
"this may because your OpenGL is missing that extension, "
"see `glewinfo` for more details",
opengl_data_type_short_name(dt));
Expand Down
12 changes: 6 additions & 6 deletions taichi/backends/opengl/opengl_data_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ namespace opengl {

inline std::string opengl_data_type_name(DataType dt) {
// https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)
if (dt == DataType::f32)
if (dt == PrimitiveType::f32)
return "float";
else if (dt == DataType::f64)
else if (dt == PrimitiveType::f64)
return "double";
else if (dt == DataType::i32)
else if (dt == PrimitiveType::i32)
return "int";
else if (dt == DataType::i64)
else if (dt == PrimitiveType::i64)
return "int64_t";
else
TI_NOT_IMPLEMENTED;
Expand All @@ -32,9 +32,9 @@ inline bool is_opengl_binary_op_different_return_type(BinaryOpType type) {
}

inline int opengl_data_address_shifter(DataType type) {
if (type == DataType::f32 || type == DataType::i32)
if (type == PrimitiveType::f32 || type == PrimitiveType::i32)
return 2;
else if (type == DataType::f64 || type == DataType::i64) {
else if (type == PrimitiveType::f64 || type == PrimitiveType::i64) {
return 3;
} else {
TI_NOT_IMPLEMENTED
Expand Down
Loading