Skip to content

Commit

Permalink
[Type] [refactor] Make PrimitiveTypeID a public enum (#1965)
Browse files Browse the repository at this point in the history
  • Loading branch information
TH3CHARLie authored Oct 16, 2020
1 parent 340f4b0 commit ceea597
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 37 deletions.
11 changes: 5 additions & 6 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,20 @@ TLANG_NAMESPACE_BEGIN
// manage its ownership more systematically.

// This part doesn't look good, but we will remove it soon anyway.
#define PER_TYPE(x) \
DataType PrimitiveType::x = \
DataType(TypeFactory::get_instance().get_primitive_type( \
PrimitiveType::primitive_type::x));
#define PER_TYPE(x) \
DataType PrimitiveType::x = DataType( \
TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::x));

#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE

DataType::DataType() : ptr_(PrimitiveType::unknown.ptr_) {
}

DataType PrimitiveType::get(PrimitiveType::primitive_type t) {
DataType PrimitiveType::get(PrimitiveTypeID t) {
if (false) {
}
#define PER_TYPE(x) else if (t == primitive_type::x) return PrimitiveType::x;
#define PER_TYPE(x) else if (t == PrimitiveTypeID::x) return PrimitiveType::x;
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE
else {
Expand Down
18 changes: 9 additions & 9 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@

TLANG_NAMESPACE_BEGIN

enum class PrimitiveTypeID : int {
#define PER_TYPE(x) x,
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE
};

class Type {
public:
virtual std::string to_string() const = 0;
Expand Down Expand Up @@ -90,24 +96,18 @@ class DataType {

class PrimitiveType : public Type {
public:
enum class primitive_type : int {
#define PER_TYPE(x) x,
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE
};

#define PER_TYPE(x) static DataType x;
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE

primitive_type type;
PrimitiveTypeID type;

PrimitiveType(primitive_type type) : type(type) {
PrimitiveType(PrimitiveTypeID type) : type(type) {
}

std::string to_string() const override;

static DataType get(primitive_type type);
static DataType get(PrimitiveTypeID type);
};

class PointerType : public Type {
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ TypeFactory &TypeFactory::get_instance() {
return *type_factory;
}

Type *TypeFactory::get_primitive_type(PrimitiveType::primitive_type id) {
Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) {
std::lock_guard<std::mutex> _(mut_);

if (primitive_types_.find(id) == primitive_types_.end()) {
Expand Down
5 changes: 2 additions & 3 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TypeFactory {
public:
static TypeFactory &get_instance();

Type *get_primitive_type(PrimitiveType::primitive_type id);
Type *get_primitive_type(PrimitiveTypeID id);

Type *get_vector_type(int num_elements, Type *element);

Expand All @@ -19,8 +19,7 @@ class TypeFactory {
private:
TypeFactory();

std::unordered_map<PrimitiveType::primitive_type, std::unique_ptr<Type>>
primitive_types_;
std::unordered_map<PrimitiveTypeID, std::unique_ptr<Type>> primitive_types_;

// TODO: use unordered map
std::map<std::pair<int, Type *>, std::unique_ptr<Type>> vector_types_;
Expand Down
6 changes: 2 additions & 4 deletions taichi/lang_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,9 @@ class TypePromotionMapping {
}

private:
std::map<
std::pair<PrimitiveType::primitive_type, PrimitiveType::primitive_type>,
PrimitiveType::primitive_type>
std::map<std::pair<PrimitiveTypeID, PrimitiveTypeID>, PrimitiveTypeID>
mapping;
static PrimitiveType::primitive_type to_primitive_type(const DataType d_) {
static PrimitiveTypeID to_primitive_type(const DataType d_) {
Type *d = d_.get_ptr();
if (d->is<PointerType>()) {
d = d->as<PointerType>()->get_pointee_type();
Expand Down
24 changes: 12 additions & 12 deletions taichi/lang_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,29 +50,29 @@ inline DataType get_data_type() {
}

template <typename T>
inline PrimitiveType::primitive_type get_primitive_data_type() {
inline PrimitiveTypeID get_primitive_data_type() {
if (std::is_same<T, float32>()) {
return PrimitiveType::primitive_type::f32;
return PrimitiveTypeID::f32;
} else if (std::is_same<T, float64>()) {
return PrimitiveType::primitive_type::f64;
return PrimitiveTypeID::f64;
} else if (std::is_same<T, bool>()) {
return PrimitiveType::primitive_type::u1;
return PrimitiveTypeID::u1;
} else if (std::is_same<T, int8>()) {
return PrimitiveType::primitive_type::i8;
return PrimitiveTypeID::i8;
} else if (std::is_same<T, int16>()) {
return PrimitiveType::primitive_type::i16;
return PrimitiveTypeID::i16;
} else if (std::is_same<T, int32>()) {
return PrimitiveType::primitive_type::i32;
return PrimitiveTypeID::i32;
} else if (std::is_same<T, int64>()) {
return PrimitiveType::primitive_type::i64;
return PrimitiveTypeID::i64;
} else if (std::is_same<T, uint8>()) {
return PrimitiveType::primitive_type::u8;
return PrimitiveTypeID::u8;
} else if (std::is_same<T, uint16>()) {
return PrimitiveType::primitive_type::u16;
return PrimitiveTypeID::u16;
} else if (std::is_same<T, uint32>()) {
return PrimitiveType::primitive_type::u32;
return PrimitiveTypeID::u32;
} else if (std::is_same<T, uint64>()) {
return PrimitiveType::primitive_type::u64;
return PrimitiveTypeID::u64;
} else {
TI_NOT_IMPLEMENTED;
}
Expand Down
4 changes: 2 additions & 2 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ void export_lang(py::module &m) {
if (t.size() != 1)
throw std::runtime_error("Invalid state!");

DataType dt = PrimitiveType::get(
(PrimitiveType::primitive_type)(t[0].cast<std::size_t>()));
DataType dt =
PrimitiveType::get((PrimitiveTypeID)(t[0].cast<std::size_t>()));

return dt;
}));
Expand Down

0 comments on commit ceea597

Please sign in to comment.