Skip to content

Commit

Permalink
class Type
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanming-hu committed Sep 29, 2020
1 parent 0c88cba commit 3ab42da
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 18 deletions.
4 changes: 2 additions & 2 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {
dest_ty = tlctx->get_data_type(stmt->ret_type.data_type);
auto dest_bits = dest_ty->getPrimitiveSizeInBits();
auto truncated = builder->CreateTrunc(
raw_arg, Type::getIntNTy(*llvm_context, dest_bits));
raw_arg, llvm::Type::getIntNTy(*llvm_context, dest_bits));
llvm_val[stmt] = builder->CreateBitCast(truncated, dest_ty);
}
}
Expand Down Expand Up @@ -1327,7 +1327,7 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) {

// per-leaf-block for loop
auto loop_index =
create_entry_block_alloca(Type::getInt32Ty(*llvm_context));
create_entry_block_alloca(llvm::Type::getInt32Ty(*llvm_context));

llvm::Value *thread_idx = nullptr, *block_dim = nullptr;

Expand Down
12 changes: 7 additions & 5 deletions taichi/lang_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace lang {

CompileConfig default_compile_config;

#define PER_TYPE(x) static DataTypeNode *x;
#define PER_TYPE(x) static Type *x;
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE

Expand All @@ -36,13 +36,15 @@ real get_cpu_frequency() {

real default_measurement_time = 1;

// Note: these Type primitives will never be freed. They are supposed to live
// with the process.
#define PER_TYPE(x) \
DataType DataType::x = \
DataType(new PrimitiveTypeNode(PrimitiveTypeNode::primitive_type::x));
DataType(new PrimitiveType(PrimitiveType::primitive_type::x));
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE

DataType PrimitiveTypeNode::get(primitive_type t) {
DataType PrimitiveType::get(primitive_type t) {
if (false) {
}
#define PER_TYPE(x) else if (t == primitive_type::x) return DataType::x;
Expand All @@ -54,14 +56,14 @@ DataType PrimitiveTypeNode::get(primitive_type t) {
}

DataType::operator std::size_t() const {
if (auto primitive = dynamic_cast<const PrimitiveTypeNode *>(ptr_)) {
if (auto primitive = dynamic_cast<const PrimitiveType *>(ptr_)) {
return (std::size_t)primitive->type;
} else {
TI_NOT_IMPLEMENTED
}
}

std::string PrimitiveTypeNode::serialize() const {
std::string PrimitiveType::serialize() const {
return data_type_name(DataType(this));
}

Expand Down
14 changes: 6 additions & 8 deletions taichi/lang_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ struct Context;

using FunctionType = std::function<void(Context &)>;

class DataTypeNode {
class Type {
public:
virtual std::string serialize() const = 0;
virtual ~DataTypeNode() {
virtual ~Type() {
}
};

Expand All @@ -34,7 +34,7 @@ class DataType {
DataType() : ptr_(unknown.ptr_) {
}

DataType(const DataTypeNode *ptr) : ptr_(ptr) {
DataType(const Type *ptr) : ptr_(ptr) {
}

bool operator==(const DataType &o) const {
Expand All @@ -52,10 +52,10 @@ class DataType {
};

private:
const DataTypeNode *ptr_;
const Type *ptr_;
};

class PrimitiveTypeNode : public DataTypeNode {
class PrimitiveType : public Type {
public:
enum class primitive_type : int {
#define PER_TYPE(x) x,
Expand All @@ -65,7 +65,7 @@ class PrimitiveTypeNode : public DataTypeNode {

primitive_type type;

PrimitiveTypeNode(primitive_type type) : type(type) {
PrimitiveType(primitive_type type) : type(type) {
}

std::string serialize() const override;
Expand Down Expand Up @@ -147,7 +147,6 @@ inline bool is_integral(DataType dt) {
}

inline bool is_signed(DataType dt) {
TI_P(dt.serialize());
TI_ASSERT(is_integral(dt));
return dt == DataType::i8 || dt == DataType::i16 || dt == DataType::i32 ||
dt == DataType::i64;
Expand All @@ -159,7 +158,6 @@ inline bool is_unsigned(DataType dt) {
}

inline DataType to_unsigned(DataType dt) {
TI_P(dt.serialize());
TI_ASSERT(is_signed(dt));
if (dt == DataType::i8)
return DataType::u8;
Expand Down
2 changes: 1 addition & 1 deletion taichi/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ std::unique_ptr<llvm::Module> TaichiLLVMContext::clone_runtime_module() {

auto patch_intrinsic = [&](std::string name, Intrinsic::ID intrin,
bool ret = true,
std::vector<Type *> types = {},
std::vector<llvm::Type *> types = {},
std::vector<llvm::Value *> extra_args = {}) {
auto func = runtime_module->getFunction(name);
TI_ERROR_UNLESS(func, "Function {} not found", name);
Expand Down
3 changes: 1 addition & 2 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ void export_lang(py::module &m) {
#undef PER_EXTENSION
.export_values();

py::class_<DataTypeNode>(m, "DataTypeNode");
py::class_<DataType>(m, "DataType").def(py::self == py::self);

py::class_<CompileConfig>(m, "CompileConfig")
Expand Down Expand Up @@ -535,7 +534,7 @@ void export_lang(py::module &m) {
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE

m.def("get_primitive_type_node", PrimitiveTypeNode::get);
m.def("get_primitive_type_node", PrimitiveType::get);

m.def("is_integral", is_integral);
m.def("is_signed", is_signed);
Expand Down

0 comments on commit 3ab42da

Please sign in to comment.