Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[type] [refactor] Promote DataType to a class #1906

Merged
merged 13 commits into from
Oct 4, 2020
1 change: 1 addition & 0 deletions misc/prtags.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@
"error" : "Error messages",
"blender" : "Blender intergration",
"export" : "Exporting kernels",
"type" : "Type system",
"release" : "Release"
}
20 changes: 10 additions & 10 deletions python/taichi/lang/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,32 @@ def is_taichi_class(rhs):

# Real types

float32 = taichi_lang_core.DataType.float32
float32 = taichi_lang_core.DataType_float32
f32 = float32
float64 = taichi_lang_core.DataType.float64
float64 = taichi_lang_core.DataType_float64
f64 = float64

real_types = [f32, f64, float]
real_type_ids = [id(t) for t in real_types]

# Integer types

int8 = taichi_lang_core.DataType.int8
int8 = taichi_lang_core.DataType_int8
i8 = int8
int16 = taichi_lang_core.DataType.int16
int16 = taichi_lang_core.DataType_int16
i16 = int16
int32 = taichi_lang_core.DataType.int32
int32 = taichi_lang_core.DataType_int32
i32 = int32
int64 = taichi_lang_core.DataType.int64
int64 = taichi_lang_core.DataType_int64
i64 = int64

uint8 = taichi_lang_core.DataType.uint8
uint8 = taichi_lang_core.DataType_uint8
u8 = uint8
uint16 = taichi_lang_core.DataType.uint16
uint16 = taichi_lang_core.DataType_uint16
u16 = uint16
uint32 = taichi_lang_core.DataType.uint32
uint32 = taichi_lang_core.DataType_uint32
u32 = uint32
uint64 = taichi_lang_core.DataType.uint64
uint64 = taichi_lang_core.DataType_uint64
u64 = uint64

integer_types = [i8, i16, i32, i64, u8, u16, u32, u64, int]
Expand Down
26 changes: 12 additions & 14 deletions taichi/backends/cc/codegen_cc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,19 +257,17 @@ class CCTransformer : public IRVisitor {
}

static std::string _get_libc_function_name(std::string name, DataType dt) {
switch (dt) {
case DataType::i32:
return name;
case DataType::i64:
return "ll" + name;
case DataType::f32:
return name + "f";
case DataType::f64:
return name;
default:
TI_ERROR("Unsupported function \"{}\" for DataType={} on C backend",
name, data_type_name(dt));
}
if (dt == DataType::i32)
return name;
else if (dt == DataType::i64)
return "ll" + name;
else if (dt == DataType::f32)
return name + "f";
else if (dt == DataType::f64)
return name;
else
TI_ERROR("Unsupported function \"{}\" for DataType={} on C backend", name,
data_type_name(dt));
}

static std::string get_libc_function_name(std::string name, DataType dt) {
Expand Down Expand Up @@ -598,7 +596,7 @@ class CCTransformer : public IRVisitor {
void emit_header(std::string f, Args &&... args) {
line_appender_header.append(std::move(f), std::move(args)...);
}
};
}; // namespace cccp

std::unique_ptr<CCKernel> CCKernelGen::compile() {
auto program = kernel->program.cc_program.get();
Expand Down
38 changes: 17 additions & 21 deletions taichi/backends/metal/data_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,24 @@ TLANG_NAMESPACE_BEGIN
namespace metal {

MetalDataType to_metal_type(DataType dt) {
switch (dt) {
#define METAL_CASE(x) \
case DataType::x: \
return MetalDataType::x

METAL_CASE(f32);
METAL_CASE(f64);
METAL_CASE(i8);
METAL_CASE(i16);
METAL_CASE(i32);
METAL_CASE(i64);
METAL_CASE(u8);
METAL_CASE(u16);
METAL_CASE(u32);
METAL_CASE(u64);
METAL_CASE(unknown);
#undef METAL_CASE

default:
TI_NOT_IMPLEMENTED;
break;
#define METAL_CASE(x) else if (dt == DataType::x) return MetalDataType::x
if (false) {
}
METAL_CASE(f32);
METAL_CASE(f64);
METAL_CASE(i8);
METAL_CASE(i16);
METAL_CASE(i32);
METAL_CASE(i64);
METAL_CASE(u8);
METAL_CASE(u16);
METAL_CASE(u32);
METAL_CASE(u64);
METAL_CASE(unknown);
else {
TI_NOT_IMPLEMENTED;
}
#undef METAL_CASE
return MetalDataType::unknown;
}

Expand Down
39 changes: 16 additions & 23 deletions taichi/backends/opengl/opengl_data_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,16 @@ namespace opengl {

inline std::string opengl_data_type_name(DataType dt) {
// https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)
switch (dt) {
case DataType::f32:
return "float";
case DataType::f64:
return "double";
case DataType::i32:
return "int";
case DataType::i64:
return "int64_t";
default:
TI_NOT_IMPLEMENTED;
break;
}
return "";
if (dt == DataType::f32)
return "float";
else if (dt == DataType::f64)
return "double";
else if (dt == DataType::i32)
return "int";
else if (dt == DataType::i64)
return "int64_t";
else
TI_NOT_IMPLEMENTED;
}

inline bool is_opengl_binary_op_infix(BinaryOpType type) {
Expand All @@ -36,15 +32,12 @@ inline bool is_opengl_binary_op_different_return_type(BinaryOpType type) {
}

inline int opengl_data_address_shifter(DataType type) {
switch (type) {
case DataType::f32:
case DataType::i32:
return 2;
case DataType::f64:
case DataType::i64:
return 3;
default:
TI_NOT_IMPLEMENTED
if (type == DataType::f32 || type == DataType::i32)
return 2;
else if (type == DataType::f64 || type == DataType::i64) {
return 3;
} else {
TI_NOT_IMPLEMENTED
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't include these OFT nits in an already-error-prone PR, they increased review difficulty.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I agree off-the-topic changes should be mostly avoided, note that this change is not really off the topic. The old switch-case implementation no longer works after refactoring, and the modifications are necessary for the build to pass.

The same for other places.

}
}

Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {
dest_ty = tlctx->get_data_type(stmt->ret_type.data_type);
auto dest_bits = dest_ty->getPrimitiveSizeInBits();
auto truncated = builder->CreateTrunc(
raw_arg, Type::getIntNTy(*llvm_context, dest_bits));
raw_arg, llvm::Type::getIntNTy(*llvm_context, dest_bits));
llvm_val[stmt] = builder->CreateBitCast(truncated, dest_ty);
}
}
Expand Down Expand Up @@ -1327,7 +1327,7 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) {

// per-leaf-block for loop
auto loop_index =
create_entry_block_alloca(Type::getInt32Ty(*llvm_context));
create_entry_block_alloca(llvm::Type::getInt32Ty(*llvm_context));

llvm::Value *thread_idx = nullptr, *block_dim = nullptr;

Expand Down
135 changes: 78 additions & 57 deletions taichi/lang_util.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Definitions of utility functions and enums

#include "lang_util.h"
#include "taichi/lang_util.h"

#include "taichi/math/linalg.h"
#include "taichi/program/arch.h"
Expand Down Expand Up @@ -29,6 +29,37 @@ real get_cpu_frequency() {

real default_measurement_time = 1;

// Note: these primitive types should never be freed. They are supposed to live
// together with the process.
#define PER_TYPE(x) \
DataType DataType::x = \
DataType(new PrimitiveType(PrimitiveType::primitive_type::x));
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE

DataType PrimitiveType::get(PrimitiveType::primitive_type t) {
if (false) {
}
#define PER_TYPE(x) else if (t == primitive_type::x) return DataType::x;
k-ye marked this conversation as resolved.
Show resolved Hide resolved
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE
else {
TI_NOT_IMPLEMENTED
}
}

DataType::operator std::size_t() const {
if (auto primitive = dynamic_cast<const PrimitiveType *>(ptr_)) {
return (std::size_t)primitive->type;
} else {
TI_NOT_IMPLEMENTED
}
}

std::string PrimitiveType::serialize() const {
return data_type_name(DataType(this));
}

real measure_cpe(std::function<void()> target,
int64 elements_per_call,
real time_second) {
Expand Down Expand Up @@ -65,30 +96,26 @@ real measure_cpe(std::function<void()> target,
}

std::string data_type_name(DataType t) {
switch (t) {
#define REGISTER_DATA_TYPE(i, j) \
case DataType::i: \
return #j;

REGISTER_DATA_TYPE(f16, float16);
REGISTER_DATA_TYPE(f32, float32);
REGISTER_DATA_TYPE(f64, float64);
REGISTER_DATA_TYPE(u1, int1);
REGISTER_DATA_TYPE(i8, int8);
REGISTER_DATA_TYPE(i16, int16);
REGISTER_DATA_TYPE(i32, int32);
REGISTER_DATA_TYPE(i64, int64);
REGISTER_DATA_TYPE(u8, uint8);
REGISTER_DATA_TYPE(u16, uint16);
REGISTER_DATA_TYPE(u32, uint32);
REGISTER_DATA_TYPE(u64, uint64);
REGISTER_DATA_TYPE(gen, generic);
REGISTER_DATA_TYPE(unknown, unknown);
#define REGISTER_DATA_TYPE(i, j) else if (t == DataType::i) return #j
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, switch-or-if nits can be put iapr.

if (false) {
}
REGISTER_DATA_TYPE(f16, float16);
REGISTER_DATA_TYPE(f32, float32);
REGISTER_DATA_TYPE(f64, float64);
REGISTER_DATA_TYPE(u1, int1);
REGISTER_DATA_TYPE(i8, int8);
REGISTER_DATA_TYPE(i16, int16);
REGISTER_DATA_TYPE(i32, int32);
REGISTER_DATA_TYPE(i64, int64);
REGISTER_DATA_TYPE(u8, uint8);
REGISTER_DATA_TYPE(u16, uint16);
REGISTER_DATA_TYPE(u32, uint32);
REGISTER_DATA_TYPE(u64, uint64);
REGISTER_DATA_TYPE(gen, generic);
REGISTER_DATA_TYPE(unknown, unknown);

#undef REGISTER_DATA_TYPE
default:
TI_NOT_IMPLEMENTED
}
else TI_NOT_IMPLEMENTED
}

std::string data_type_format(DataType dt) {
Expand All @@ -110,48 +137,42 @@ std::string data_type_format(DataType dt) {
}

int data_type_size(DataType t) {
switch (t) {
case DataType::f16:
return 2;
case DataType::gen:
return 0;
case DataType::unknown:
return -1;

#define REGISTER_DATA_TYPE(i, j) \
case DataType::i: \
return sizeof(j);

REGISTER_DATA_TYPE(f32, float32);
REGISTER_DATA_TYPE(f64, float64);
REGISTER_DATA_TYPE(i8, int8);
REGISTER_DATA_TYPE(i16, int16);
REGISTER_DATA_TYPE(i32, int32);
REGISTER_DATA_TYPE(i64, int64);
REGISTER_DATA_TYPE(u8, uint8);
REGISTER_DATA_TYPE(u16, uint16);
REGISTER_DATA_TYPE(u32, uint32);
REGISTER_DATA_TYPE(u64, uint64);
if (false) {
} else if (t == DataType::f16)
return 2;
else if (t == DataType::gen)
return 0;
else if (t == DataType::unknown)
return -1;

#define REGISTER_DATA_TYPE(i, j) else if (t == DataType::i) return sizeof(j)

REGISTER_DATA_TYPE(f32, float32);
REGISTER_DATA_TYPE(f64, float64);
REGISTER_DATA_TYPE(i8, int8);
REGISTER_DATA_TYPE(i16, int16);
REGISTER_DATA_TYPE(i32, int32);
REGISTER_DATA_TYPE(i64, int64);
REGISTER_DATA_TYPE(u8, uint8);
REGISTER_DATA_TYPE(u16, uint16);
REGISTER_DATA_TYPE(u32, uint32);
REGISTER_DATA_TYPE(u64, uint64);

#undef REGISTER_DATA_TYPE
default:
TI_NOT_IMPLEMENTED
else {
TI_NOT_IMPLEMENTED
}
}

std::string data_type_short_name(DataType t) {
switch (t) {
#define PER_TYPE(i) \
case DataType::i: \
return #i;

if (false) {
}
#define PER_TYPE(i) else if (t == DataType::i) return #i;
#include "taichi/inc/data_type.inc.h"

#undef PER_TYPE
default:
TI_NOT_IMPLEMENTED
}
}
else
TI_NOT_IMPLEMENTED
} // namespace lang

std::string snode_type_name(SNodeType t) {
switch (t) {
Expand Down
Loading